diff --git a/chia/_tests/core/data_layer/test_data_rpc.py b/chia/_tests/core/data_layer/test_data_rpc.py index d637154a2bc8..e678ca89f5bc 100644 --- a/chia/_tests/core/data_layer/test_data_rpc.py +++ b/chia/_tests/core/data_layer/test_data_rpc.py @@ -46,12 +46,13 @@ ProofLayer, Status, StoreProofs, + get_delta_filename_path, + get_full_tree_filename_path, key_hash, leaf_hash, ) from chia.data_layer.data_layer_wallet import DataLayerWallet, verify_offer from chia.data_layer.data_store import DataStore -from chia.data_layer.download_data import get_delta_filename_path, get_full_tree_filename_path from chia.rpc.data_layer_rpc_api import DataLayerRpcApi from chia.rpc.data_layer_rpc_client import DataLayerRpcClient from chia.rpc.wallet_rpc_api import WalletRpcApi @@ -129,6 +130,7 @@ async def init_data_layer( manage_data_interval: int = 5, maximum_full_file_count: Optional[int] = None, group_files_by_store: bool = False, + enable_batch_autoinsert: bool = True, ) -> AsyncIterator[DataLayer]: async with init_data_layer_service( wallet_rpc_port, @@ -137,7 +139,7 @@ async def init_data_layer( wallet_service, manage_data_interval, maximum_full_file_count, - True, + enable_batch_autoinsert, group_files_by_store, ) as data_layer_service: yield data_layer_service._api.data_layer @@ -251,6 +253,7 @@ def create_mnemonic(seed: bytes = b"ab") -> str: @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_create_insert_get( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -330,6 +333,7 @@ async def test_create_insert_get( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_upsert( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -360,6 +364,7 @@ async def test_upsert( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_create_double_insert( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -397,6 +402,7 @@ async def test_create_double_insert( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_keys_values_ancestors( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -443,7 +449,7 @@ async def test_keys_values_ancestors( assert key in dic val = await data_rpc_api.get_ancestors({"id": store_id.hex(), "hash": val["keys_values"][4]["hash"]}) # todo better assertions for get_ancestors result - assert len(val["ancestors"]) == 3 + assert len(val["ancestors"]) == 1 res_before = await data_rpc_api.get_root({"id": store_id.hex()}) assert res_before["confirmed"] is True assert res_before["timestamp"] > 0 @@ -473,6 +479,7 @@ async def test_keys_values_ancestors( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_get_roots( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -526,6 +533,7 @@ async def test_get_roots( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_get_root_history( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -580,6 +588,7 @@ 
async def test_get_root_history( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_get_kv_diff( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -647,13 +656,19 @@ async def test_get_kv_diff( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_batch_update_matches_single_operations( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: wallet_rpc_api, full_node_api, wallet_rpc_port, ph, bt = await init_wallet_and_node( self_hostname, one_wallet_and_one_simulator_services ) - async with init_data_layer(wallet_rpc_port=wallet_rpc_port, bt=bt, db_path=tmp_path) as data_layer: + async with init_data_layer( + wallet_rpc_port=wallet_rpc_port, + bt=bt, + db_path=tmp_path, + enable_batch_autoinsert=False, + ) as data_layer: data_rpc_api = DataLayerRpcApi(data_layer) res = await data_rpc_api.create_data_store({}) assert res is not None @@ -719,6 +734,7 @@ async def test_batch_update_matches_single_operations( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_get_owned_stores( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -759,6 +775,7 @@ async def test_get_owned_stores( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_subscriptions( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -1591,6 +1608,7 @@ class MakeAndTakeReference: indirect=["offer_setup"], ) @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_make_and_take_offer(offer_setup: OfferSetup, reference: MakeAndTakeReference) -> None: offer_setup = await populate_offer_setup(offer_setup=offer_setup, count=reference.entries_to_insert) @@ -1703,6 +1721,7 @@ async def test_make_and_then_take_offer_invalid_inclusion_key( @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_verify_offer_rpc_valid(bare_data_layer_api: DataLayerRpcApi) -> None: reference = make_one_take_one_reference @@ -1721,6 +1740,7 @@ async def test_verify_offer_rpc_valid(bare_data_layer_api: DataLayerRpcApi) -> N @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_verify_offer_rpc_invalid(bare_data_layer_api: DataLayerRpcApi) -> None: reference = make_one_take_one_reference broken_taker_offer = copy.deepcopy(reference.make_offer_response) @@ -1741,6 +1761,7 @@ async def test_verify_offer_rpc_invalid(bare_data_layer_api: DataLayerRpcApi) -> @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_make_offer_failure_rolls_back_db(offer_setup: OfferSetup) -> None: # TODO: only needs the maker and db? wallet? 
reference = make_one_take_one_reference @@ -1783,6 +1804,7 @@ async def test_make_offer_failure_rolls_back_db(offer_setup: OfferSetup) -> None ], ) @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_make_and_cancel_offer(offer_setup: OfferSetup, reference: MakeAndTakeReference) -> None: offer_setup = await populate_offer_setup(offer_setup=offer_setup, count=reference.entries_to_insert) @@ -1859,6 +1881,7 @@ async def test_make_and_cancel_offer(offer_setup: OfferSetup, reference: MakeAnd ], ) @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_make_and_cancel_offer_then_update( offer_setup: OfferSetup, reference: MakeAndTakeReference, secure: bool ) -> None: @@ -1948,6 +1971,7 @@ async def test_make_and_cancel_offer_then_update( ], ) @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_make_and_cancel_offer_not_secure_clears_pending_roots( offer_setup: OfferSetup, reference: MakeAndTakeReference, @@ -1990,6 +2014,7 @@ async def test_make_and_cancel_offer_not_secure_clears_pending_roots( @pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") @pytest.mark.anyio +@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") async def test_get_sync_status( self_hostname: str, one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices, tmp_path: Path ) -> None: @@ -3028,7 +3053,12 @@ async def test_pagination_cmds( wallet_rpc_api, full_node_api, wallet_rpc_port, ph, bt = await init_wallet_and_node( self_hostname, one_wallet_and_one_simulator_services ) - async with init_data_layer_service(wallet_rpc_port=wallet_rpc_port, bt=bt, db_path=tmp_path) as data_layer_service: + async with init_data_layer_service( + wallet_rpc_port=wallet_rpc_port, + bt=bt, + db_path=tmp_path, + enable_batch_autoinsert=False, + ) as data_layer_service: assert data_layer_service.rpc_server is not None rpc_port = data_layer_service.rpc_server.listen_port data_layer = data_layer_service._api.data_layer @@ -3180,7 +3210,7 @@ async def test_pagination_cmds( if max_page_size is None or max_page_size == 100: assert keys == { "keys": ["0x61616161", "0x6161"], - "root_hash": "0x889a4a61b17be799ae9d36831246672ef857a24091f54481431a83309d4e890e", + "root_hash": "0x3f4ae7b8e10ef48b3114843537d5def989ee0a3b6568af7e720a71730f260fa1", "success": True, "total_bytes": 6, "total_pages": 1, @@ -3200,7 +3230,7 @@ async def test_pagination_cmds( "value": "0x6161", }, ], - "root_hash": "0x889a4a61b17be799ae9d36831246672ef857a24091f54481431a83309d4e890e", + "root_hash": "0x3f4ae7b8e10ef48b3114843537d5def989ee0a3b6568af7e720a71730f260fa1", "success": True, "total_bytes": 9, "total_pages": 1, @@ -3217,7 +3247,7 @@ async def test_pagination_cmds( elif max_page_size == 5: assert keys == { "keys": ["0x61616161"], - "root_hash": "0x889a4a61b17be799ae9d36831246672ef857a24091f54481431a83309d4e890e", + "root_hash": "0x3f4ae7b8e10ef48b3114843537d5def989ee0a3b6568af7e720a71730f260fa1", "success": True, "total_bytes": 6, "total_pages": 2, @@ -3231,7 +3261,7 @@ async def test_pagination_cmds( "value": "0x61", } ], - "root_hash": "0x889a4a61b17be799ae9d36831246672ef857a24091f54481431a83309d4e890e", + "root_hash": "0x3f4ae7b8e10ef48b3114843537d5def989ee0a3b6568af7e720a71730f260fa1", "success": True, "total_bytes": 9, "total_pages": 2, @@ -3689,6 +3719,7 @@ async def test_multistore_update( await 
data_rpc_api.multistore_batch_update({"store_updates": store_updates}) +@pytest.mark.skip @pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules") @pytest.mark.anyio async def test_unsubmitted_batch_db_migration( diff --git a/chia/_tests/core/data_layer/test_data_store.py b/chia/_tests/core/data_layer/test_data_store.py index 79565c5684c1..b2593b606fff 100644 --- a/chia/_tests/core/data_layer/test_data_store.py +++ b/chia/_tests/core/data_layer/test_data_store.py @@ -11,20 +11,18 @@ from dataclasses import dataclass from pathlib import Path from random import Random -from typing import Any, Callable, Optional, cast +from typing import Any, Callable, Optional import aiohttp -import aiosqlite import pytest from chia._tests.core.data_layer.util import Example, add_0123_example, add_01234567_example from chia._tests.util.misc import BenchmarkRunner, Marks, boolean_datacases, datacases -from chia.data_layer.data_layer_errors import KeyNotFoundError, NodeHashError, TreeGenerationIncrementingError +from chia.data_layer.data_layer_errors import KeyNotFoundError, TreeGenerationIncrementingError from chia.data_layer.data_layer_util import ( DiffData, InternalNode, Node, - NodeType, OperationType, ProofOfInclusion, ProofOfInclusionLayer, @@ -35,20 +33,19 @@ Subscription, TerminalNode, _debug_dump, - leaf_hash, -) -from chia.data_layer.data_store import DataStore -from chia.data_layer.download_data import ( get_delta_filename_path, get_full_tree_filename_path, - insert_from_delta_file, - insert_into_data_store_from_file, - write_files_for_root, + leaf_hash, ) +from chia.data_layer.data_store import DataStore +from chia.data_layer.download_data import insert_from_delta_file, write_files_for_root +from chia.data_layer.util.benchmark import generate_datastore +from chia.data_layer.util.merkle_blob import RawLeafMerkleNode from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 from chia.util.byte_types import hexstr_to_bytes from chia.util.db_wrapper import DBWrapper2, generate_in_memory_db_uri +from chia.util.lru_cache import LRUCache log = logging.getLogger(__name__) @@ -57,27 +54,18 @@ table_columns: dict[str, list[str]] = { - "node": ["hash", "node_type", "left", "right", "key", "value"], "root": ["tree_id", "generation", "node_hash", "status"], + "subscriptions": ["tree_id", "url", "ignore_till", "num_consecutive_failures", "from_wallet"], + "schema": ["version_id", "applied_at"], + "merkleblob": ["hash", "blob", "store_id"], + "ids": ["kv_id", "blob", "store_id"], + "hashes": ["hash", "kid", "vid", "store_id"], + "nodes": ["store_id", "hash", "root_hash", "generation", "idx"], } # TODO: Someday add tests for malformed DB data to make sure we handle it gracefully # and with good error messages. 
- - -@pytest.mark.anyio -async def test_valid_node_values_fixture_are_valid(data_store: DataStore, valid_node_values: dict[str, Any]) -> None: - async with data_store.db_wrapper.writer() as writer: - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - valid_node_values, - ) - - @pytest.mark.parametrize(argnames=["table_name", "expected_columns"], argvalues=table_columns.items()) @pytest.mark.anyio async def test_create_creates_tables_and_columns( @@ -211,46 +199,6 @@ async def test_get_tree_generation_returns_none_when_none_available( await raw_data_store.get_tree_generation(store_id=store_id) -@pytest.mark.anyio -async def test_insert_internal_node_does_nothing_if_matching(data_store: DataStore, store_id: bytes32) -> None: - await add_01234567_example(data_store=data_store, store_id=store_id) - - kv_node = await data_store.get_node_by_key(key=b"\x04", store_id=store_id) - ancestors = await data_store.get_ancestors(node_hash=kv_node.hash, store_id=store_id) - parent = ancestors[0] - - async with data_store.db_wrapper.reader() as reader: - cursor = await reader.execute("SELECT * FROM node") - before = await cursor.fetchall() - - await data_store._insert_internal_node(left_hash=parent.left_hash, right_hash=parent.right_hash) - - async with data_store.db_wrapper.reader() as reader: - cursor = await reader.execute("SELECT * FROM node") - after = await cursor.fetchall() - - assert after == before - - -@pytest.mark.anyio -async def test_insert_terminal_node_does_nothing_if_matching(data_store: DataStore, store_id: bytes32) -> None: - await add_01234567_example(data_store=data_store, store_id=store_id) - - kv_node = await data_store.get_node_by_key(key=b"\x04", store_id=store_id) - - async with data_store.db_wrapper.reader() as reader: - cursor = await reader.execute("SELECT * FROM node") - before = await cursor.fetchall() - - await data_store._insert_terminal_node(key=kv_node.key, value=kv_node.value) - - async with data_store.db_wrapper.reader() as reader: - cursor = await reader.execute("SELECT * FROM node") - after = await cursor.fetchall() - - assert after == before - - @pytest.mark.anyio async def test_build_a_tree( data_store: DataStore, @@ -293,7 +241,7 @@ async def test_get_ancestors(data_store: DataStore, store_id: bytes32) -> None: "c852ecd8fb61549a0a42f9eb9dde65e6c94a01934dbd9c1d35ab94e2a0ae58e2", ] - ancestors_2 = await data_store.get_ancestors_optimized(node_hash=reference_node_hash, store_id=store_id) + ancestors_2 = await data_store.get_ancestors(node_hash=reference_node_hash, store_id=store_id) assert ancestors == ancestors_2 @@ -306,6 +254,10 @@ async def test_get_ancestors_optimized(data_store: DataStore, store_id: bytes32) first_insertions = [True, False, True, False, True, True, False, True, False, True, True, False, False, True, False] deleted_all = False node_count = 0 + node_hashes: list[bytes32] = [] + hash_to_key: dict[bytes32, bytes] = {} + node_hash: Optional[bytes32] + for i in range(1000): is_insert = False if i <= 14: @@ -318,12 +270,10 @@ async def test_get_ancestors_optimized(data_store: DataStore, store_id: bytes32) if not deleted_all: while node_count > 0: node_count -= 1 - seed = bytes32(b"0" * 32) - node_hash = await data_store.get_terminal_node_for_seed(store_id, seed) + node_hash = random.choice(node_hashes) assert node_hash is not None - node = await data_store.get_node(node_hash) - assert isinstance(node, TerminalNode) - await 
data_store.delete(key=node.key, store_id=store_id, status=Status.COMMITTED) + await data_store.delete(key=hash_to_key[node_hash], store_id=store_id, status=Status.COMMITTED) + node_hashes.remove(node_hash) deleted_all = True is_insert = True else: @@ -335,10 +285,10 @@ async def test_get_ancestors_optimized(data_store: DataStore, store_id: bytes32) key = (i % 200).to_bytes(4, byteorder="big") value = (i % 200).to_bytes(4, byteorder="big") seed = Program.to((key, value)).get_tree_hash() - node_hash = await data_store.get_terminal_node_for_seed(store_id, seed) + node_hash = None if len(node_hashes) == 0 else random.choice(node_hashes) if is_insert: node_count += 1 - side = None if node_hash is None else data_store.get_side_for_seed(seed) + side = None if node_hash is None else (Side.LEFT if seed[0] < 128 else Side.RIGHT) insert_result = await data_store.insert( key=key, @@ -346,10 +296,11 @@ async def test_get_ancestors_optimized(data_store: DataStore, store_id: bytes32) store_id=store_id, reference_node_hash=node_hash, side=side, - use_optimized=False, status=Status.COMMITTED, ) node_hash = insert_result.node_hash + hash_to_key[node_hash] = key + node_hashes.append(node_hash) if node_hash is not None: generation = await data_store.get_tree_generation(store_id=store_id) current_ancestors = await data_store.get_ancestors(node_hash=node_hash, store_id=store_id) @@ -357,34 +308,28 @@ async def test_get_ancestors_optimized(data_store: DataStore, store_id: bytes32) else: node_count -= 1 assert node_hash is not None - node = await data_store.get_node(node_hash) - assert isinstance(node, TerminalNode) - await data_store.delete(key=node.key, store_id=store_id, use_optimized=False, status=Status.COMMITTED) + node_hashes.remove(node_hash) + await data_store.delete(key=hash_to_key[node_hash], store_id=store_id, status=Status.COMMITTED) for generation, node_hash, expected_ancestors in ancestors: - current_ancestors = await data_store.get_ancestors_optimized( + current_ancestors = await data_store.get_ancestors( node_hash=node_hash, store_id=store_id, generation=generation ) assert current_ancestors == expected_ancestors @pytest.mark.anyio -@pytest.mark.parametrize( - "use_optimized", - [True, False], -) @pytest.mark.parametrize( "num_batches", [1, 5, 10, 25], ) -async def test_batch_update( +async def test_batch_update_against_single_operations( data_store: DataStore, store_id: bytes32, - use_optimized: bool, tmp_path: Path, num_batches: int, ) -> None: - total_operations = 1000 if use_optimized else 100 + total_operations = 1000 num_ops_per_batch = total_operations // num_batches saved_batches: list[list[dict[str, Any]]] = [] saved_kv: list[list[TerminalNode]] = [] @@ -412,7 +357,6 @@ async def test_batch_update( key=key, value=value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) else: @@ -420,7 +364,6 @@ async def test_batch_update( key=key, new_value=value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) action = "insert" if op_type == "insert" else "upsert" @@ -432,7 +375,6 @@ async def test_batch_update( await single_op_data_store.delete( key=key, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) batch.append({"action": "delete", "key": key}) @@ -446,7 +388,6 @@ async def test_batch_update( key=key, new_value=new_value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) keys_values[key] = new_value @@ -469,38 +410,13 @@ async def test_batch_update( assert {node.key: node.value for node 
in current_kv} == { node.key: node.value for node in saved_kv[batch_number] } - queue: list[bytes32] = [root.node_hash] - ancestors: dict[bytes32, bytes32] = {} - while len(queue) > 0: - node_hash = queue.pop(0) - expected_ancestors = [] - ancestor = node_hash - while ancestor in ancestors: - ancestor = ancestors[ancestor] - expected_ancestors.append(ancestor) - result_ancestors = await data_store.get_ancestors_optimized(node_hash, store_id) - assert [node.hash for node in result_ancestors] == expected_ancestors - node = await data_store.get_node(node_hash) - if isinstance(node, InternalNode): - queue.append(node.left_hash) - queue.append(node.right_hash) - ancestors[node.left_hash] = node_hash - ancestors[node.right_hash] = node_hash all_kv = await data_store.get_keys_values(store_id) assert {node.key: node.value for node in all_kv} == keys_values @pytest.mark.anyio -@pytest.mark.parametrize( - "use_optimized", - [True, False], -) -async def test_upsert_ignores_existing_arguments( - data_store: DataStore, - store_id: bytes32, - use_optimized: bool, -) -> None: +async def test_upsert_ignores_existing_arguments(data_store: DataStore, store_id: bytes32) -> None: key = b"key" value = b"value1" @@ -508,7 +424,6 @@ async def test_upsert_ignores_existing_arguments( key=key, value=value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) node = await data_store.get_node_by_key(key, store_id) @@ -519,7 +434,6 @@ async def test_upsert_ignores_existing_arguments( key=key, new_value=new_value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) node = await data_store.get_node_by_key(key, store_id) @@ -529,7 +443,6 @@ async def test_upsert_ignores_existing_arguments( key=key, new_value=new_value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) node = await data_store.get_node_by_key(key, store_id) @@ -540,7 +453,6 @@ async def test_upsert_ignores_existing_arguments( key=key2, new_value=value, store_id=store_id, - use_optimized=use_optimized, status=Status.COMMITTED, ) node = await data_store.get_node_by_key(key2, store_id) @@ -575,30 +487,24 @@ async def test_insert_batch_reference_and_side( ) assert new_root_hash is not None, "batch insert failed or failed to update root" - parent = await data_store.get_node(new_root_hash) - assert isinstance(parent, InternalNode) + merkle_blob = await data_store.get_merkle_blob(new_root_hash) + nodes_with_indexes = merkle_blob.get_nodes_with_indexes() + nodes = [pair[1] for pair in nodes_with_indexes] + assert len(nodes) == 3 + assert isinstance(nodes[1], RawLeafMerkleNode) + assert isinstance(nodes[2], RawLeafMerkleNode) + left_terminal_node = await data_store.get_terminal_node(nodes[1].key, nodes[1].value, store_id) + right_terminal_node = await data_store.get_terminal_node(nodes[2].key, nodes[2].value, store_id) if side == Side.LEFT: - child = await data_store.get_node(parent.left_hash) - assert parent.left_hash == child.hash + assert left_terminal_node.key == b"key2" + assert right_terminal_node.key == b"key1" elif side == Side.RIGHT: - child = await data_store.get_node(parent.right_hash) - assert parent.right_hash == child.hash + assert left_terminal_node.key == b"key1" + assert right_terminal_node.key == b"key2" else: # pragma: no cover raise Exception("invalid side for test") -@pytest.mark.anyio -async def test_ancestor_table_unique_inserts(data_store: DataStore, store_id: bytes32) -> None: - await add_0123_example(data_store=data_store, store_id=store_id) - hash_1 = 
bytes32.from_hexstr("0763561814685fbf92f6ca71fbb1cb11821951450d996375c239979bd63e9535") - hash_2 = bytes32.from_hexstr("924be8ff27e84cba17f5bc918097f8410fab9824713a4668a21c8e060a8cab40") - await data_store._insert_ancestor_table(hash_1, hash_2, store_id, 2) - await data_store._insert_ancestor_table(hash_1, hash_2, store_id, 2) - with pytest.raises(Exception, match="^Requested insertion of ancestor"): - await data_store._insert_ancestor_table(hash_1, hash_1, store_id, 2) - await data_store._insert_ancestor_table(hash_1, hash_2, store_id, 2) - - @pytest.mark.anyio async def test_get_pairs( data_store: DataStore, @@ -609,7 +515,7 @@ async def test_get_pairs( pairs = await data_store.get_keys_values(store_id=store_id) - assert [node.hash for node in pairs] == example.terminal_nodes + assert {node.hash for node in pairs} == set(example.terminal_nodes) @pytest.mark.anyio @@ -662,37 +568,6 @@ async def test_inserting_duplicate_key_fails( ) -@pytest.mark.anyio() -async def test_inserting_invalid_length_hash_raises_original_exception( - data_store: DataStore, -) -> None: - with pytest.raises(aiosqlite.IntegrityError): - # casting since we are testing an invalid case - await data_store._insert_node( - node_hash=cast(bytes32, b"\x05"), - node_type=NodeType.TERMINAL, - left_hash=None, - right_hash=None, - key=b"\x06", - value=b"\x07", - ) - - -@pytest.mark.anyio() -async def test_inserting_invalid_length_ancestor_hash_raises_original_exception( - data_store: DataStore, - store_id: bytes32, -) -> None: - with pytest.raises(aiosqlite.IntegrityError): - # casting since we are testing an invalid case - await data_store._insert_ancestor_table( - left_hash=bytes32(b"\x01" * 32), - right_hash=bytes32(b"\x02" * 32), - store_id=store_id, - generation=0, - ) - - @pytest.mark.anyio() async def test_autoinsert_balances_from_scratch(data_store: DataStore, store_id: bytes32) -> None: random = Random() @@ -705,7 +580,7 @@ async def test_autoinsert_balances_from_scratch(data_store: DataStore, store_id: insert_result = await data_store.autoinsert(key, value, store_id, status=Status.COMMITTED) hashes.append(insert_result.node_hash) - heights = {node_hash: len(await data_store.get_ancestors_optimized(node_hash, store_id)) for node_hash in hashes} + heights = {node_hash: len(await data_store.get_ancestors(node_hash, store_id)) for node_hash in hashes} too_tall = {hash: height for hash, height in heights.items() if height > 14} assert too_tall == {} assert 11 <= statistics.mean(heights.values()) <= 12 @@ -715,7 +590,7 @@ async def test_autoinsert_balances_from_scratch(data_store: DataStore, store_id: async def test_autoinsert_balances_gaps(data_store: DataStore, store_id: bytes32) -> None: random = Random() random.seed(101, version=2) - hashes = [] + hashes: list[bytes32] = [] for i in range(2000): key = (i + 100).to_bytes(4, byteorder="big") @@ -723,7 +598,7 @@ async def test_autoinsert_balances_gaps(data_store: DataStore, store_id: bytes32 if i == 0 or i > 10: insert_result = await data_store.autoinsert(key, value, store_id, status=Status.COMMITTED) else: - reference_node_hash = await data_store.get_terminal_node_for_seed(store_id, bytes32.zeros) + reference_node_hash = hashes[-1] insert_result = await data_store.insert( key=key, value=value, @@ -732,11 +607,11 @@ async def test_autoinsert_balances_gaps(data_store: DataStore, store_id: bytes32 side=Side.LEFT, status=Status.COMMITTED, ) - ancestors = await data_store.get_ancestors_optimized(insert_result.node_hash, store_id) + ancestors = await 
data_store.get_ancestors(insert_result.node_hash, store_id) assert len(ancestors) == i hashes.append(insert_result.node_hash) - heights = {node_hash: len(await data_store.get_ancestors_optimized(node_hash, store_id)) for node_hash in hashes} + heights = {node_hash: len(await data_store.get_ancestors(node_hash, store_id)) for node_hash in hashes} too_tall = {hash: height for hash, height in heights.items() if height > 14} assert too_tall == {} assert 11 <= statistics.mean(heights.values()) <= 12 @@ -1036,46 +911,6 @@ async def test_check_roots_are_incrementing_gap(raw_data_store: DataStore) -> No await raw_data_store._check_roots_are_incrementing() -@pytest.mark.anyio -async def test_check_hashes_internal(raw_data_store: DataStore) -> None: - async with raw_data_store.db_wrapper.writer() as writer: - await writer.execute( - "INSERT INTO node(hash, node_type, left, right) VALUES(:hash, :node_type, :left, :right)", - { - "hash": a_bytes_32, - "node_type": NodeType.INTERNAL, - "left": a_bytes_32, - "right": a_bytes_32, - }, - ) - - with pytest.raises( - NodeHashError, - match=r"\n +000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f$", - ): - await raw_data_store._check_hashes() - - -@pytest.mark.anyio -async def test_check_hashes_terminal(raw_data_store: DataStore) -> None: - async with raw_data_store.db_wrapper.writer() as writer: - await writer.execute( - "INSERT INTO node(hash, node_type, key, value) VALUES(:hash, :node_type, :key, :value)", - { - "hash": a_bytes_32, - "node_type": NodeType.TERMINAL, - "key": Program.to((1, 2)).as_bin(), - "value": Program.to((1, 2)).as_bin(), - }, - ) - - with pytest.raises( - NodeHashError, - match=r"\n +000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f$", - ): - await raw_data_store._check_hashes() - - @pytest.mark.anyio async def test_root_state(data_store: DataStore, store_id: bytes32) -> None: key = b"\x01\x02" @@ -1127,28 +962,29 @@ async def test_kv_diff(data_store: DataStore, store_id: bytes32) -> None: insertions = 0 expected_diff: set[DiffData] = set() root_start = None + for i in range(500): key = (i + 100).to_bytes(4, byteorder="big") value = (i + 200).to_bytes(4, byteorder="big") seed = leaf_hash(key=key, value=value) - node_hash = await data_store.get_terminal_node_for_seed(store_id, seed) + node = await data_store.get_terminal_node_for_seed(seed, store_id) + side_seed = bytes(seed)[0] + side = None if node is None else (Side.LEFT if side_seed < 128 else Side.RIGHT) + if random.randint(0, 4) > 0 or insertions < 10: insertions += 1 - side = None if node_hash is None else data_store.get_side_for_seed(seed) - + reference_node_hash = node.hash if node is not None else None await data_store.insert( key=key, value=value, store_id=store_id, - reference_node_hash=node_hash, - side=side, status=Status.COMMITTED, + reference_node_hash=reference_node_hash, + side=side, ) if i > 200: expected_diff.add(DiffData(OperationType.INSERT, key, value)) else: - assert node_hash is not None - node = await data_store.get_node(node_hash) assert isinstance(node, TerminalNode) await data_store.delete(key=node.key, store_id=store_id, status=Status.COMMITTED) if i > 200: @@ -1275,6 +1111,39 @@ async def test_subscribe_unsubscribe(data_store: DataStore, store_id: bytes32) - ] +@pytest.mark.anyio +async def test_unsubscribe_clears_databases(data_store: DataStore, store_id: bytes32) -> None: + num_inserts = 100 + await data_store.subscribe(Subscription(store_id, [])) + for value in range(num_inserts): + await data_store.insert( + 
key=value.to_bytes(4, byteorder="big"), + value=value.to_bytes(4, byteorder="big"), + store_id=store_id, + reference_node_hash=None, + side=None, + status=Status.COMMITTED, + ) + await data_store.add_node_hashes(store_id) + + tables = ["merkleblob", "ids", "hashes", "nodes"] + for table in tables: + async with data_store.db_wrapper.reader() as reader: + async with reader.execute(f"SELECT COUNT(*) FROM {table}") as cursor: + row_count = await cursor.fetchone() + assert row_count is not None + assert row_count[0] > 0 + + await data_store.unsubscribe(store_id) + + for table in tables: + async with data_store.db_wrapper.reader() as reader: + async with reader.execute(f"SELECT COUNT(*) FROM {table}") as cursor: + row_count = await cursor.fetchone() + assert row_count is not None + assert row_count[0] == 0 + + @pytest.mark.anyio async def test_server_selection(data_store: DataStore, store_id: bytes32) -> None: start_timestamp = 1000 @@ -1451,6 +1320,7 @@ async def test_data_server_files( counter += 1 await data_store_server.insert_batch(store_id, changelist, status=Status.COMMITTED) root = await data_store_server.get_tree_root(store_id) + await data_store_server.add_node_hashes(store_id) await write_files_for_root( data_store_server, store_id, root, tmp_path, 0, group_by_store=group_files_by_store ) @@ -1466,7 +1336,7 @@ async def test_data_server_files( else: filename = get_delta_filename_path(tmp_path, store_id, root.node_hash, generation, group_files_by_store) assert filename.exists() - await insert_into_data_store_from_file(data_store, store_id, root.node_hash, tmp_path.joinpath(filename)) + await data_store.insert_into_data_store_from_file(store_id, root.node_hash, tmp_path.joinpath(filename)) current_root = await data_store.get_tree_root(store_id=store_id) assert current_root.node_hash == root.node_hash generation += 1 @@ -1614,6 +1484,22 @@ async def test_benchmark_batch_insert_speed( ) +@datacases( + BatchInsertBenchmarkCase( + pre=0, + count=50, + limit=2, + ), +) +@pytest.mark.anyio +async def test_benchmark_tool( + benchmark_runner: BenchmarkRunner, + case: BatchInsertBenchmarkCase, +) -> None: + with benchmark_runner.assert_runtime(seconds=case.limit): + await generate_datastore(case.count) + + @datacases( BatchesInsertBenchmarkCase( count=50, @@ -1648,189 +1534,6 @@ async def test_benchmark_batch_insert_speed_multiple_batches( ) -@pytest.mark.anyio -async def test_delete_store_data(raw_data_store: DataStore) -> None: - store_id = bytes32.zeros - store_id_2 = bytes32(b"\0" * 31 + b"\1") - await raw_data_store.create_tree(store_id=store_id, status=Status.COMMITTED) - await raw_data_store.create_tree(store_id=store_id_2, status=Status.COMMITTED) - total_keys = 4 - keys = [key.to_bytes(4, byteorder="big") for key in range(total_keys)] - batch1 = [ - {"action": "insert", "key": keys[0], "value": keys[0]}, - {"action": "insert", "key": keys[1], "value": keys[1]}, - ] - batch2 = batch1.copy() - batch1.append({"action": "insert", "key": keys[2], "value": keys[2]}) - batch2.append({"action": "insert", "key": keys[3], "value": keys[3]}) - assert batch1 != batch2 - await raw_data_store.insert_batch(store_id, batch1, status=Status.COMMITTED) - await raw_data_store.insert_batch(store_id_2, batch2, status=Status.COMMITTED) - keys_values_before = await raw_data_store.get_keys_values(store_id_2) - async with raw_data_store.db_wrapper.reader() as reader: - result = await reader.execute("SELECT * FROM node") - nodes = await result.fetchall() - kv_nodes_before = {} - for node in nodes: - if 
node["key"] is not None: - kv_nodes_before[node["key"]] = node["value"] - assert [kv_nodes_before[key] for key in keys] == keys - await raw_data_store.delete_store_data(store_id) - # Deleting from `node` table doesn't alter other stores. - keys_values_after = await raw_data_store.get_keys_values(store_id_2) - assert keys_values_before == keys_values_after - async with raw_data_store.db_wrapper.reader() as reader: - result = await reader.execute("SELECT * FROM node") - nodes = await result.fetchall() - kv_nodes_after = {} - for node in nodes: - if node["key"] is not None: - kv_nodes_after[node["key"]] = node["value"] - for i in range(total_keys): - if i != 2: - assert kv_nodes_after[keys[i]] == keys[i] - else: - # `keys[2]` was only present in the first store. - assert keys[i] not in kv_nodes_after - assert not await raw_data_store.store_id_exists(store_id) - await raw_data_store.delete_store_data(store_id_2) - async with raw_data_store.db_wrapper.reader() as reader: - async with reader.execute("SELECT COUNT(*) FROM node") as cursor: - row_count = await cursor.fetchone() - assert row_count is not None - assert row_count[0] == 0 - assert not await raw_data_store.store_id_exists(store_id_2) - - -@pytest.mark.anyio -async def test_delete_store_data_multiple_stores(raw_data_store: DataStore) -> None: - # Make sure inserting and deleting the same data works - for repetition in range(2): - num_stores = 50 - total_keys = 150 - keys_deleted_per_store = 3 - store_ids = [bytes32(i.to_bytes(32, byteorder="big")) for i in range(num_stores)] - for store_id in store_ids: - await raw_data_store.create_tree(store_id=store_id, status=Status.COMMITTED) - original_keys = [key.to_bytes(4, byteorder="big") for key in range(total_keys)] - batches = [] - for i in range(num_stores): - batch = [ - {"action": "insert", "key": key, "value": key} for key in original_keys[i * keys_deleted_per_store :] - ] - batches.append(batch) - - for store_id, batch in zip(store_ids, batches): - await raw_data_store.insert_batch(store_id, batch, status=Status.COMMITTED) - - for tree_index in range(num_stores): - async with raw_data_store.db_wrapper.reader() as reader: - result = await reader.execute("SELECT * FROM node") - nodes = await result.fetchall() - - keys = {node["key"] for node in nodes if node["key"] is not None} - assert len(keys) == total_keys - tree_index * keys_deleted_per_store - keys_after_index = set(original_keys[tree_index * keys_deleted_per_store :]) - keys_before_index = set(original_keys[: tree_index * keys_deleted_per_store]) - assert keys_after_index.issubset(keys) - assert keys.isdisjoint(keys_before_index) - await raw_data_store.delete_store_data(store_ids[tree_index]) - - async with raw_data_store.db_wrapper.reader() as reader: - async with reader.execute("SELECT COUNT(*) FROM node") as cursor: - row_count = await cursor.fetchone() - assert row_count is not None - assert row_count[0] == 0 - - -@pytest.mark.parametrize("common_keys_count", [1, 250, 499]) -@pytest.mark.anyio -async def test_delete_store_data_with_common_values(raw_data_store: DataStore, common_keys_count: int) -> None: - store_id_1 = bytes32(b"\x00" * 31 + b"\x01") - store_id_2 = bytes32(b"\x00" * 31 + b"\x02") - - await raw_data_store.create_tree(store_id=store_id_1, status=Status.COMMITTED) - await raw_data_store.create_tree(store_id=store_id_2, status=Status.COMMITTED) - - key_offset = 1000 - total_keys_per_store = 500 - assert common_keys_count < key_offset - common_keys = {key.to_bytes(4, byteorder="big") for key in 
range(common_keys_count)} - unique_keys_1 = { - (key + key_offset).to_bytes(4, byteorder="big") for key in range(total_keys_per_store - common_keys_count) - } - unique_keys_2 = { - (key + (2 * key_offset)).to_bytes(4, byteorder="big") for key in range(total_keys_per_store - common_keys_count) - } - - batch1 = [{"action": "insert", "key": key, "value": key} for key in common_keys.union(unique_keys_1)] - batch2 = [{"action": "insert", "key": key, "value": key} for key in common_keys.union(unique_keys_2)] - - await raw_data_store.insert_batch(store_id_1, batch1, status=Status.COMMITTED) - await raw_data_store.insert_batch(store_id_2, batch2, status=Status.COMMITTED) - - await raw_data_store.delete_store_data(store_id_1) - async with raw_data_store.db_wrapper.reader() as reader: - result = await reader.execute("SELECT * FROM node") - nodes = await result.fetchall() - - keys = {node["key"] for node in nodes if node["key"] is not None} - # Since one store got all its keys deleted, we're left only with the keys of the other store. - assert len(keys) == total_keys_per_store - assert keys.intersection(unique_keys_1) == set() - assert keys.symmetric_difference(common_keys.union(unique_keys_2)) == set() - - -@pytest.mark.anyio -@pytest.mark.parametrize("pending_status", [Status.PENDING, Status.PENDING_BATCH]) -async def test_delete_store_data_protects_pending_roots(raw_data_store: DataStore, pending_status: Status) -> None: - num_stores = 5 - total_keys = 15 - store_ids = [bytes32(i.to_bytes(32, byteorder="big")) for i in range(num_stores)] - for store_id in store_ids: - await raw_data_store.create_tree(store_id=store_id, status=Status.COMMITTED) - original_keys = [key.to_bytes(4, byteorder="big") for key in range(total_keys)] - batches = [] - keys_per_pending_root = 2 - - for i in range(num_stores - 1): - start_index = i * keys_per_pending_root - end_index = (i + 1) * keys_per_pending_root - batch = [{"action": "insert", "key": key, "value": key} for key in original_keys[start_index:end_index]] - batches.append(batch) - for store_id, batch in zip(store_ids, batches): - await raw_data_store.insert_batch(store_id, batch, status=pending_status) - - store_id = store_ids[-1] - batch = [{"action": "insert", "key": key, "value": key} for key in original_keys] - await raw_data_store.insert_batch(store_id, batch, status=Status.COMMITTED) - - async with raw_data_store.db_wrapper.reader() as reader: - result = await reader.execute("SELECT * FROM node") - nodes = await result.fetchall() - - keys = {node["key"] for node in nodes if node["key"] is not None} - assert keys == set(original_keys) - - await raw_data_store.delete_store_data(store_id) - async with raw_data_store.db_wrapper.reader() as reader: - result = await reader.execute("SELECT * FROM node") - nodes = await result.fetchall() - - keys = {node["key"] for node in nodes if node["key"] is not None} - assert keys == set(original_keys[: (num_stores - 1) * keys_per_pending_root]) - - for index in range(num_stores - 1): - store_id = store_ids[index] - root = await raw_data_store.get_pending_root(store_id) - assert root is not None - await raw_data_store.change_root_status(root, Status.COMMITTED) - kv = await raw_data_store.get_keys_values(store_id=store_id) - start_index = index * keys_per_pending_root - end_index = (index + 1) * keys_per_pending_root - assert {pair.key for pair in kv} == set(original_keys[start_index:end_index]) - - @pytest.mark.anyio @boolean_datacases(name="group_files_by_store", true="group by singleton", false="don't group by 
singleton") @pytest.mark.parametrize("max_full_files", [1, 2, 5]) @@ -1843,7 +1546,6 @@ async def test_insert_from_delta_file( group_files_by_store: bool, max_full_files: int, ) -> None: - await data_store.create_tree(store_id=store_id, status=Status.COMMITTED) num_files = 5 for generation in range(num_files): key = generation.to_bytes(4, byteorder="big") @@ -1854,30 +1556,34 @@ async def test_insert_from_delta_file( store_id=store_id, status=Status.COMMITTED, ) + await data_store.add_node_hashes(store_id) root = await data_store.get_tree_root(store_id=store_id) - assert root.generation == num_files + 1 + assert root.generation == num_files root_hashes = [] tmp_path_1 = tmp_path.joinpath("1") tmp_path_2 = tmp_path.joinpath("2") - for generation in range(1, num_files + 2): + for generation in range(1, num_files + 1): root = await data_store.get_tree_root(store_id=store_id, generation=generation) await write_files_for_root(data_store, store_id, root, tmp_path_1, 0, False, group_files_by_store) root_hashes.append(bytes32.zeros if root.node_hash is None else root.node_hash) store_path = tmp_path_1.joinpath(f"{store_id}") if group_files_by_store else tmp_path_1 with os.scandir(store_path) as entries: filenames = {entry.name for entry in entries} - assert len(filenames) == 2 * (num_files + 1) + assert len(filenames) == 2 * num_files for filename in filenames: if "full" in filename: store_path.joinpath(filename).unlink() with os.scandir(store_path) as entries: filenames = {entry.name for entry in entries} - assert len(filenames) == num_files + 1 + assert len(filenames) == num_files kv_before = await data_store.get_keys_values(store_id=store_id) await data_store.rollback_to_generation(store_id, 0) + async with data_store.db_wrapper.writer() as writer: + await writer.execute("DELETE FROM merkleblob") + root = await data_store.get_tree_root(store_id=store_id) assert root.generation == 0 os.rename(store_path, tmp_path_2) @@ -1950,10 +1656,10 @@ async def mock_http_download_2( assert success root = await data_store.get_tree_root(store_id=store_id) - assert root.generation == num_files + 1 + assert root.generation == num_files with os.scandir(store_path) as entries: filenames = {entry.name for entry in entries} - assert len(filenames) == num_files + 1 + max_full_files # 6 deltas and max_full_files full files + assert len(filenames) == num_files + max_full_files - 1 kv = await data_store.get_keys_values(store_id=store_id) assert kv == kv_before @@ -2013,6 +1719,7 @@ async def test_insert_from_delta_file_correct_file_exists( store_id=store_id, status=Status.COMMITTED, ) + await data_store.add_node_hashes(store_id) root = await data_store.get_tree_root(store_id=store_id) assert root.generation == num_files + 1 @@ -2035,6 +1742,8 @@ async def test_insert_from_delta_file_correct_file_exists( await data_store.rollback_to_generation(store_id, 0) root = await data_store.get_tree_root(store_id=store_id) assert root.generation == 0 + async with data_store.db_wrapper.writer() as writer: + await writer.execute("DELETE FROM merkleblob") sinfo = ServerInfo("http://127.0.0.1/8003", 0, 0) success = await insert_from_delta_file( @@ -2079,6 +1788,7 @@ async def test_insert_from_delta_file_incorrect_file_exists( store_id=store_id, status=Status.COMMITTED, ) + await data_store.add_node_hashes(store_id) root = await data_store.get_tree_root(store_id=store_id) assert root.generation == 2 @@ -2187,19 +1897,6 @@ async def test_update_keys(data_store: DataStore, store_id: bytes32, use_upsert: num_keys += new_keys 
-@pytest.mark.anyio -async def test_migration_unknown_version(data_store: DataStore) -> None: - async with data_store.db_wrapper.writer() as writer: - await writer.execute( - "INSERT INTO schema(version_id) VALUES(:version_id)", - { - "version_id": "unknown version", - }, - ) - with pytest.raises(Exception, match="Unknown version"): - await data_store.migrate_db() - - async def _check_ancestors( data_store: DataStore, store_id: bytes32, root_hash: bytes32 ) -> dict[bytes32, Optional[bytes32]]: @@ -2230,195 +1927,58 @@ async def _check_ancestors( @pytest.mark.anyio -async def test_build_ancestor_table(data_store: DataStore, store_id: bytes32) -> None: - num_values = 1000 - changelist: list[dict[str, Any]] = [] - for value in range(num_values): - value_bytes = value.to_bytes(4, byteorder="big") - changelist.append({"action": "upsert", "key": value_bytes, "value": value_bytes}) - await data_store.insert_batch( - store_id=store_id, - changelist=changelist, - status=Status.PENDING, - ) - - pending_root = await data_store.get_pending_root(store_id=store_id) - assert pending_root is not None - assert pending_root.node_hash is not None - await data_store.change_root_status(pending_root, Status.COMMITTED) - await data_store.build_ancestor_table_for_latest_root(store_id=store_id) - - assert pending_root.node_hash is not None - await _check_ancestors(data_store, store_id, pending_root.node_hash) - - -@pytest.mark.anyio -async def test_sparse_ancestor_table(data_store: DataStore, store_id: bytes32) -> None: - num_values = 100 - for value in range(num_values): - value_bytes = value.to_bytes(4, byteorder="big") - await data_store.autoinsert( - key=value_bytes, - value=value_bytes, - store_id=store_id, - status=Status.COMMITTED, +async def test_migration_unknown_version(data_store: DataStore, tmp_path: Path) -> None: + async with data_store.db_wrapper.writer() as writer: + await writer.execute( + "INSERT INTO schema(version_id) VALUES(:version_id)", + { + "version_id": "unknown version", + }, ) - root = await data_store.get_tree_root(store_id=store_id) - assert root.node_hash is not None - ancestors = await _check_ancestors(data_store, store_id, root.node_hash) - - # Check the ancestor table is sparse - root_generation = root.generation - current_generation_count = 0 - previous_generation_count = 0 - for node_hash, ancestor_hash in ancestors.items(): - async with data_store.db_wrapper.reader() as reader: - if ancestor_hash is not None: - cursor = await reader.execute( - "SELECT MAX(generation) AS generation FROM ancestors WHERE hash == :hash AND ancestor == :ancestor", - {"hash": node_hash, "ancestor": ancestor_hash}, - ) - else: - cursor = await reader.execute( - "SELECT MAX(generation) AS generation FROM ancestors WHERE hash == :hash AND ancestor IS NULL", - {"hash": node_hash}, - ) - row = await cursor.fetchone() - assert row is not None - generation = row["generation"] - assert generation <= root_generation - if generation == root_generation: - current_generation_count += 1 - else: - previous_generation_count += 1 - - assert current_generation_count == 15 - assert previous_generation_count == 184 - - -async def get_all_nodes(data_store: DataStore, store_id: bytes32) -> list[Node]: - root = await data_store.get_tree_root(store_id) - assert root.node_hash is not None - root_node = await data_store.get_node(root.node_hash) - nodes: list[Node] = [] - queue: list[Node] = [root_node] - - while len(queue) > 0: - node = queue.pop(0) - nodes.append(node) - if isinstance(node, InternalNode): - left_node = 
await data_store.get_node(node.left_hash) - right_node = await data_store.get_node(node.right_hash) - queue.append(left_node) - queue.append(right_node) - - return nodes - - -@pytest.mark.anyio -async def test_get_nodes(data_store: DataStore, store_id: bytes32) -> None: - num_values = 50 - changelist: list[dict[str, Any]] = [] - - for value in range(num_values): - value_bytes = value.to_bytes(4, byteorder="big") - changelist.append({"action": "upsert", "key": value_bytes, "value": value_bytes}) - await data_store.insert_batch( - store_id=store_id, - changelist=changelist, - status=Status.COMMITTED, - ) - - expected_nodes = await get_all_nodes(data_store, store_id) - nodes = await data_store.get_nodes([node.hash for node in expected_nodes]) - assert nodes == expected_nodes - - node_hash = bytes32.zeros - node_hash_2 = bytes32([0] * 31 + [1]) - with pytest.raises(Exception, match=f"^Nodes not found for hashes: {node_hash.hex()}, {node_hash_2.hex()}"): - await data_store.get_nodes([node_hash, node_hash_2] + [node.hash for node in expected_nodes]) + with pytest.raises(Exception, match="Unknown version"): + await data_store.migrate_db(tmp_path) +@boolean_datacases(name="group_files_by_store", false="group by singleton", true="don't group by singleton") @pytest.mark.anyio -@pytest.mark.parametrize("pre", [0, 2048]) -@pytest.mark.parametrize("batch_size", [25, 100, 500]) -async def test_get_leaf_at_minimum_height( +async def test_migration( data_store: DataStore, store_id: bytes32, - pre: int, - batch_size: int, + group_files_by_store: bool, + tmp_path: Path, ) -> None: - num_values = 1000 - value_offset = 1000000 - all_min_leafs: set[TerminalNode] = set() + num_batches = 10 + num_ops_per_batch = 100 + keys: list[bytes] = [] + counter = 0 + random = Random() + random.seed(100, version=2) - if pre > 0: - # This builds a complete binary tree, in order to test more than one batch in the queue before finding the leaf + for batch in range(num_batches): changelist: list[dict[str, Any]] = [] + for operation in range(num_ops_per_batch): + if random.randint(0, 4) > 0 or len(keys) == 0: + key = counter.to_bytes(4, byteorder="big") + value = (2 * counter).to_bytes(4, byteorder="big") + keys.append(key) + changelist.append({"action": "insert", "key": key, "value": value}) + else: + key = random.choice(keys) + keys.remove(key) + changelist.append({"action": "delete", "key": key}) + counter += 1 + await data_store.insert_batch(store_id, changelist, status=Status.COMMITTED) + root = await data_store.get_tree_root(store_id) + await data_store.add_node_hashes(store_id) + await write_files_for_root(data_store, store_id, root, tmp_path, 0, group_by_store=group_files_by_store) - for value in range(pre): - value_bytes = (value * value).to_bytes(8, byteorder="big") - changelist.append({"action": "upsert", "key": value_bytes, "value": value_bytes}) - await data_store.insert_batch( - store_id=store_id, - changelist=changelist, - status=Status.COMMITTED, - ) - - for value in range(num_values): - value_bytes = value.to_bytes(4, byteorder="big") - # Use autoinsert instead of `insert_batch` to get a more randomly shaped tree - await data_store.autoinsert( - key=value_bytes, - value=value_bytes, - store_id=store_id, - status=Status.COMMITTED, - ) - - if (value + 1) % batch_size == 0: - hash_to_parent: dict[bytes32, InternalNode] = {} - root = await data_store.get_tree_root(store_id) - assert root.node_hash is not None - min_leaf = await data_store.get_leaf_at_minimum_height(root.node_hash, hash_to_parent) - all_nodes = await 
get_all_nodes(data_store, store_id) - heights: dict[bytes32, int] = {} - heights[root.node_hash] = 0 - min_leaf_height = None - - for node in all_nodes: - if isinstance(node, InternalNode): - heights[node.left_hash] = heights[node.hash] + 1 - heights[node.right_hash] = heights[node.hash] + 1 - else: - if min_leaf_height is not None: - min_leaf_height = min(min_leaf_height, heights[node.hash]) - else: - min_leaf_height = heights[node.hash] - - assert min_leaf_height is not None - if pre > 0: - assert min_leaf_height >= 11 - for node in all_nodes: - if isinstance(node, TerminalNode): - assert node == min_leaf - assert heights[min_leaf.hash] == min_leaf_height - break - if node.left_hash in hash_to_parent: - assert hash_to_parent[node.left_hash] == node - if node.right_hash in hash_to_parent: - assert hash_to_parent[node.right_hash] == node - - # Push down the min height leaf, so on the next iteration we get a different leaf - pushdown_height = 20 - for repeat in range(pushdown_height): - value_bytes = (value + (repeat + 1) * value_offset).to_bytes(4, byteorder="big") - await data_store.insert( - key=value_bytes, - value=value_bytes, - store_id=store_id, - reference_node_hash=min_leaf.hash, - side=Side.RIGHT, - status=Status.COMMITTED, - ) - assert min_leaf not in all_min_leafs - all_min_leafs.add(min_leaf) + kv_before = await data_store.get_keys_values(store_id=store_id) + async with data_store.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: + tables = [table for table in table_columns.keys() if table != "root"] + for table in tables: + await writer.execute(f"DELETE FROM {table}") + + data_store.recent_merkle_blobs = LRUCache(capacity=128) + assert await data_store.get_keys_values(store_id=store_id) == [] + await data_store.migrate_db(tmp_path) + assert await data_store.get_keys_values(store_id=store_id) == kv_before diff --git a/chia/_tests/core/data_layer/test_data_store_schema.py b/chia/_tests/core/data_layer/test_data_store_schema.py index 6e6ea10eb928..018d1c480b5b 100644 --- a/chia/_tests/core/data_layer/test_data_store_schema.py +++ b/chia/_tests/core/data_layer/test_data_store_schema.py @@ -1,171 +1,17 @@ from __future__ import annotations import sqlite3 -from typing import Any import pytest -from chia._tests.core.data_layer.util import add_01234567_example, create_valid_node_values -from chia.data_layer.data_layer_util import NodeType, Side, Status +from chia._tests.core.data_layer.util import add_01234567_example +from chia.data_layer.data_layer_util import Status from chia.data_layer.data_store import DataStore from chia.types.blockchain_format.sized_bytes import bytes32 pytestmark = pytest.mark.data_layer -@pytest.mark.anyio -async def test_node_update_fails(data_store: DataStore, store_id: bytes32) -> None: - await add_01234567_example(data_store=data_store, store_id=store_id) - node = await data_store.get_node_by_key(key=b"\x04", store_id=store_id) - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^updates not allowed to the node table$"): - await writer.execute( - "UPDATE node SET value = :value WHERE hash == :hash", - { - "hash": node.hash, - "value": node.value, - }, - ) - - -@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32})) -@pytest.mark.anyio -async def test_node_hash_must_be_32( - data_store: DataStore, - store_id: bytes32, - length: int, - valid_node_values: dict[str, Any], -) -> None: - valid_node_values["hash"] = bytes([0] * length) - - async with 
data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - valid_node_values, - ) - - -@pytest.mark.anyio -async def test_node_hash_must_not_be_null( - data_store: DataStore, - store_id: bytes32, - valid_node_values: dict[str, Any], -) -> None: - valid_node_values["hash"] = None - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: node.hash$"): - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - valid_node_values, - ) - - -@pytest.mark.anyio -async def test_node_type_must_be_valid( - data_store: DataStore, - node_type: NodeType, - bad_node_type: int, - valid_node_values: dict[str, Any], -) -> None: - valid_node_values["node_type"] = bad_node_type - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - valid_node_values, - ) - - -@pytest.mark.parametrize(argnames="side", argvalues=Side) -@pytest.mark.anyio -async def test_node_internal_child_not_null(data_store: DataStore, store_id: bytes32, side: Side) -> None: - await add_01234567_example(data_store=data_store, store_id=store_id) - node_a = await data_store.get_node_by_key(key=b"\x02", store_id=store_id) - node_b = await data_store.get_node_by_key(key=b"\x04", store_id=store_id) - - values = create_valid_node_values(node_type=NodeType.INTERNAL, left_hash=node_a.hash, right_hash=node_b.hash) - - if side == Side.LEFT: - values["left"] = None - elif side == Side.RIGHT: - values["right"] = None - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - values, - ) - - -@pytest.mark.parametrize(argnames="bad_child_hash", argvalues=[b"\x01" * 32, b"\0" * 31, b""]) -@pytest.mark.parametrize(argnames="side", argvalues=Side) -@pytest.mark.anyio -async def test_node_internal_must_be_valid_reference( - data_store: DataStore, - store_id: bytes32, - bad_child_hash: bytes, - side: Side, -) -> None: - await add_01234567_example(data_store=data_store, store_id=store_id) - node_a = await data_store.get_node_by_key(key=b"\x02", store_id=store_id) - node_b = await data_store.get_node_by_key(key=b"\x04", store_id=store_id) - - values = create_valid_node_values(node_type=NodeType.INTERNAL, left_hash=node_a.hash, right_hash=node_b.hash) - - if side == Side.LEFT: - values["left"] = bad_child_hash - elif side == Side.RIGHT: - values["right"] = bad_child_hash - else: # pragma: no cover - assert False - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"): - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - values, - ) - - -@pytest.mark.parametrize(argnames="key_or_value", argvalues=["key", "value"]) 
-@pytest.mark.anyio -async def test_node_terminal_key_value_not_null(data_store: DataStore, store_id: bytes32, key_or_value: str) -> None: - await add_01234567_example(data_store=data_store, store_id=store_id) - - values = create_valid_node_values(node_type=NodeType.TERMINAL) - values[key_or_value] = None - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - values, - ) - - @pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32})) @pytest.mark.anyio async def test_root_store_id_must_be_32(data_store: DataStore, store_id: bytes32, length: int) -> None: @@ -250,21 +96,6 @@ async def test_root_generation_must_not_be_null(data_store: DataStore, store_id: ) -@pytest.mark.anyio -async def test_root_node_hash_must_reference(data_store: DataStore) -> None: - values = {"tree_id": bytes32.zeros, "generation": 0, "node_hash": bytes32.zeros, "status": Status.PENDING} - - async with data_store.db_wrapper.writer() as writer: - with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"): - await writer.execute( - """ - INSERT INTO root(tree_id, generation, node_hash, status) - VALUES(:tree_id, :generation, :node_hash, :status) - """, - values, - ) - - @pytest.mark.parametrize(argnames="bad_status", argvalues=sorted(set(range(-20, 20)) - {*Status})) @pytest.mark.anyio async def test_root_status_must_be_valid(data_store: DataStore, store_id: bytes32, bad_status: int) -> None: @@ -319,44 +150,6 @@ async def test_root_store_id_generation_must_be_unique(data_store: DataStore, st ) -@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32})) -@pytest.mark.anyio -async def test_ancestors_ancestor_must_be_32( - data_store: DataStore, - store_id: bytes32, - length: int, -) -> None: - async with data_store.db_wrapper.writer() as writer: - node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01") - with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await writer.execute( - """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES(:hash, :ancestor, :tree_id, :generation) - """, - {"hash": node_hash, "ancestor": bytes([0] * length), "tree_id": bytes32.zeros, "generation": 0}, - ) - - -@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32})) -@pytest.mark.anyio -async def test_ancestors_store_id_must_be_32( - data_store: DataStore, - store_id: bytes32, - length: int, -) -> None: - async with data_store.db_wrapper.writer() as writer: - node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01") - with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"): - await writer.execute( - """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES(:hash, :ancestor, :tree_id, :generation) - """, - {"hash": node_hash, "ancestor": bytes32.zeros, "tree_id": bytes([0] * length), "generation": 0}, - ) - - @pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32})) @pytest.mark.anyio async def test_subscriptions_store_id_must_be_32( diff --git a/chia/_tests/core/data_layer/test_merkle_blob.py b/chia/_tests/core/data_layer/test_merkle_blob.py index 4cd3067d26db..df2698d0be80 100644 --- a/chia/_tests/core/data_layer/test_merkle_blob.py +++ 
b/chia/_tests/core/data_layer/test_merkle_blob.py @@ -370,6 +370,7 @@ def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None: with pytest.raises(InvalidIndexError): merkle_blob.get_raw_node(index) + with pytest.raises(InvalidIndexError): merkle_blob.get_metadata(index) diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py index 3aef5354ac76..04ef73ec6564 100644 --- a/chia/data_layer/data_layer.py +++ b/chia/data_layer/data_layer.py @@ -41,18 +41,14 @@ TerminalNode, Unspecified, UnsubscribeData, + get_delta_filename_path, + get_full_tree_filename_path, leaf_hash, unspecified, ) from chia.data_layer.data_layer_wallet import DataLayerWallet, Mirror, SingletonRecord, verify_offer from chia.data_layer.data_store import DataStore -from chia.data_layer.download_data import ( - delete_full_file_if_exists, - get_delta_filename_path, - get_full_tree_filename_path, - insert_from_delta_file, - write_files_for_root, -) +from chia.data_layer.download_data import delete_full_file_if_exists, insert_from_delta_file, write_files_for_root from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections from chia.rpc.wallet_request_types import LogIn from chia.rpc.wallet_rpc_client import WalletRpcClient @@ -199,7 +195,7 @@ async def manage(self) -> AsyncIterator[None]: async with DataStore.managed(database=self.db_path, sql_log_path=sql_log_path) as self._data_store: self._wallet_rpc = await self.wallet_rpc_init - await self._data_store.migrate_db() + await self._data_store.migrate_db(self.server_files_location) self.periodically_manage_data_task = asyncio.create_task(self.periodically_manage_data()) try: yield @@ -254,7 +250,6 @@ async def batch_update( ) -> Optional[TransactionRecord]: status = Status.PENDING if submit_on_chain else Status.PENDING_BATCH await self.batch_insert(store_id=store_id, changelist=changelist, status=status) - await self.data_store.clean_node_table() if submit_on_chain: return await self.publish_update(store_id=store_id, fee=fee) @@ -288,8 +283,6 @@ async def multistore_batch_update( status = Status.PENDING if submit_on_chain else Status.PENDING_BATCH await self.batch_insert(store_id=store_id, changelist=changelist, status=status) - await self.data_store.clean_node_table() - if submit_on_chain: update_dictionary: dict[bytes32, bytes32] = {} for store_id in store_ids: @@ -532,7 +525,6 @@ async def _update_confirmation_status(self, store_id: bytes32) -> None: and pending_root.status == Status.PENDING ): await self.data_store.change_root_status(pending_root, Status.COMMITTED) - await self.data_store.build_ancestor_table_for_latest_root(store_id=store_id) await self.data_store.clear_pending_roots(store_id=store_id) async def fetch_and_validate(self, store_id: bytes32) -> None: @@ -835,8 +827,6 @@ async def process_unsubscribe(self, store_id: bytes32, retain_data: bool) -> Non # stop tracking first, then unsubscribe from the data store await self.wallet_rpc.dl_stop_tracking(store_id) await self.data_store.unsubscribe(store_id) - if not retain_data: - await self.data_store.delete_store_data(store_id) self.log.info(f"Unsubscribed to {store_id}") for file_path in paths: @@ -1137,7 +1127,6 @@ async def make_offer( verify_offer(maker=offer.maker, taker=offer.taker, summary=summary) - await self.data_store.clean_node_table() return offer async def take_offer( @@ -1197,8 +1186,6 @@ async def take_offer( }, } - await self.data_store.clean_node_table() - # Excluding wallet from transaction since failures in the wallet may occur # after the 
transaction is submitted to the chain. If we roll back data we # may lose published data. diff --git a/chia/data_layer/data_layer_util.py b/chia/data_layer/data_layer_util.py index d7561ea21a26..6d7eb688136d 100644 --- a/chia/data_layer/data_layer_util.py +++ b/chia/data_layer/data_layer_util.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from enum import Enum, IntEnum from hashlib import sha256 +from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union import aiosqlite @@ -48,6 +49,44 @@ def key_hash(key: bytes) -> bytes32: return bytes32(sha256(b"\1" + key).digest()) +def get_full_tree_filename(store_id: bytes32, node_hash: bytes32, generation: int, group_by_store: bool = False) -> str: + if group_by_store: + return f"{store_id}/{node_hash}-full-{generation}-v1.0.dat" + return f"{store_id}-{node_hash}-full-{generation}-v1.0.dat" + + +def get_delta_filename(store_id: bytes32, node_hash: bytes32, generation: int, group_by_store: bool = False) -> str: + if group_by_store: + return f"{store_id}/{node_hash}-delta-{generation}-v1.0.dat" + return f"{store_id}-{node_hash}-delta-{generation}-v1.0.dat" + + +def get_full_tree_filename_path( + foldername: Path, + store_id: bytes32, + node_hash: bytes32, + generation: int, + group_by_store: bool = False, +) -> Path: + if group_by_store: + path = foldername.joinpath(f"{store_id}") + return path.joinpath(f"{node_hash}-full-{generation}-v1.0.dat") + return foldername.joinpath(f"{store_id}-{node_hash}-full-{generation}-v1.0.dat") + + +def get_delta_filename_path( + foldername: Path, + store_id: bytes32, + node_hash: bytes32, + generation: int, + group_by_store: bool = False, +) -> Path: + if group_by_store: + path = foldername.joinpath(f"{store_id}") + return path.joinpath(f"{node_hash}-delta-{generation}-v1.0.dat") + return foldername.joinpath(f"{store_id}-{node_hash}-delta-{generation}-v1.0.dat") + + @dataclasses.dataclass(frozen=True) class PaginationData: total_pages: int diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index ead6f196cfd2..d01a26235289 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import copy import logging from collections import defaultdict from collections.abc import AsyncIterator, Awaitable @@ -11,7 +12,7 @@ import aiosqlite -from chia.data_layer.data_layer_errors import KeyNotFoundError, NodeHashError, TreeGenerationIncrementingError +from chia.data_layer.data_layer_errors import KeyNotFoundError, TreeGenerationIncrementingError from chia.data_layer.data_layer_util import ( DiffData, InsertResult, @@ -24,7 +25,6 @@ NodeType, OperationType, ProofOfInclusion, - ProofOfInclusionLayer, Root, SerializedNode, ServerInfo, @@ -33,6 +33,7 @@ Subscription, TerminalNode, Unspecified, + get_delta_filename_path, get_hashes_for_page, internal_hash, key_hash, @@ -40,9 +41,19 @@ row_to_node, unspecified, ) -from chia.types.blockchain_format.program import Program +from chia.data_layer.util.merkle_blob import KVId, MerkleBlob, NodeMetadata +from chia.data_layer.util.merkle_blob import NodeType as NodeTypeMerkleBlob +from chia.data_layer.util.merkle_blob import ( + RawInternalMerkleNode, + RawLeafMerkleNode, + TreeIndex, + null_parent, + pack_raw_node, + undefined_index, +) from chia.types.blockchain_format.sized_bytes import bytes32 -from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2 +from chia.util.db_wrapper import DBWrapper2 +from chia.util.lru_cache import 
LRUCache log = logging.getLogger(__name__) @@ -56,6 +67,7 @@ class DataStore: """A key/value store with the pairs being terminal nodes in a CLVM object tree.""" db_wrapper: DBWrapper2 + recent_merkle_blobs: LRUCache[bytes32, MerkleBlob] @classmethod @contextlib.asynccontextmanager @@ -75,46 +87,10 @@ async def managed( row_factory=aiosqlite.Row, log_path=sql_log_path, ) as db_wrapper: - self = cls(db_wrapper=db_wrapper) + recent_merkle_blobs: LRUCache[bytes32, MerkleBlob] = LRUCache(capacity=128) + self = cls(db_wrapper=db_wrapper, recent_merkle_blobs=recent_merkle_blobs) async with db_wrapper.writer() as writer: - await writer.execute( - f""" - CREATE TABLE IF NOT EXISTS node( - hash BLOB PRIMARY KEY NOT NULL CHECK(length(hash) == 32), - node_type INTEGER NOT NULL CHECK( - ( - node_type == {int(NodeType.INTERNAL)} - AND left IS NOT NULL - AND right IS NOT NULL - AND key IS NULL - AND value IS NULL - ) - OR - ( - node_type == {int(NodeType.TERMINAL)} - AND left IS NULL - AND right IS NULL - AND key IS NOT NULL - AND value IS NOT NULL - ) - ), - left BLOB REFERENCES node, - right BLOB REFERENCES node, - key BLOB, - value BLOB - ) - """ - ) - await writer.execute( - """ - CREATE TRIGGER IF NOT EXISTS no_node_updates - BEFORE UPDATE ON node - BEGIN - SELECT RAISE(FAIL, 'updates not allowed to the node table'); - END - """ - ) await writer.execute( f""" CREATE TABLE IF NOT EXISTS root( @@ -124,25 +100,7 @@ async def managed( status INTEGER NOT NULL CHECK( {" OR ".join(f"status == {status}" for status in Status)} ), - PRIMARY KEY(tree_id, generation), - FOREIGN KEY(node_hash) REFERENCES node(hash) - ) - """ - ) - # TODO: Add ancestor -> hash relationship, this might involve temporarily - # deferring the foreign key enforcement due to the insertion order - # and the node table also enforcing a similar relationship in the - # other direction. 
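get_full_tree_filename_path and get_delta_filename_path (moved into data_layer_util above) encode the on-disk layout that the migration code below also relies on when probing for recovery files. A standalone sketch of the two layouts they produce, using made-up identifiers; full-tree files follow the same shape with "full" in place of "delta":

# Standalone sketch of the file layout produced by the path helpers above.
# The directory, store id, and node hash below are made up for illustration.
from pathlib import Path

server_files = Path("/tmp/server_files_location")  # hypothetical
store_id_hex = "aa" * 32
node_hash_hex = "bb" * 32
generation = 3

# group_by_store=False: one flat directory, both ids joined into the file name
flat = server_files / f"{store_id_hex}-{node_hash_hex}-delta-{generation}-v1.0.dat"

# group_by_store=True: a sub-directory per store, only the node hash in the file name
grouped = server_files / store_id_hex / f"{node_hash_hex}-delta-{generation}-v1.0.dat"

print(flat)     # .../aaaa...-bbbb...-delta-3-v1.0.dat
print(grouped)  # .../aaaa.../bbbb...-delta-3-v1.0.dat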
- # FOREIGN KEY(ancestor) REFERENCES ancestors(ancestor) - await writer.execute( - """ - CREATE TABLE IF NOT EXISTS ancestors( - hash BLOB NOT NULL REFERENCES node, - ancestor BLOB CHECK(length(ancestor) == 32), - tree_id BLOB NOT NULL CHECK(length(tree_id) == 32), - generation INTEGER NOT NULL, - PRIMARY KEY(hash, tree_id, generation), - FOREIGN KEY(ancestor) REFERENCES node(hash) + PRIMARY KEY(tree_id, generation) ) """ ) @@ -173,7 +131,56 @@ async def managed( ) await writer.execute( """ - CREATE INDEX IF NOT EXISTS node_key_index ON node(key) + CREATE TABLE IF NOT EXISTS merkleblob( + hash BLOB, + blob BLOB, + store_id BLOB NOT NULL CHECK(length(store_id) == 32), + PRIMARY KEY(store_id, hash) + ) + """ + ) + await writer.execute( + """ + CREATE TABLE IF NOT EXISTS ids( + kv_id INTEGER PRIMARY KEY AUTOINCREMENT, + blob BLOB, + store_id BLOB NOT NULL CHECK(length(store_id) == 32) + ) + """ + ) + await writer.execute( + """ + CREATE TABLE IF NOT EXISTS hashes( + hash BLOB, + kid INTEGER, + vid INTEGER, + store_id BLOB NOT NULL CHECK(length(store_id) == 32), + PRIMARY KEY(store_id, hash), + FOREIGN KEY (kid) REFERENCES ids(kv_id), + FOREIGN KEY (vid) REFERENCES ids(kv_id) + ) + """ + ) + await writer.execute( + """ + CREATE TABLE IF NOT EXISTS nodes( + store_id BLOB NOT NULL CHECK(length(store_id) == 32), + hash BLOB NOT NULL, + root_hash BLOB NOT NULL, + generation INTEGER NOT NULL CHECK(generation >= 0), + idx INTEGER NOT NULL, + PRIMARY KEY(store_id, hash) + ) + """ + ) + await writer.execute( + """ + CREATE INDEX IF NOT EXISTS ids_blob_index ON ids(blob) + """ + ) + await writer.execute( + """ + CREATE INDEX IF NOT EXISTS ids_store_id_index ON ids(store_id) """ ) @@ -184,20 +191,85 @@ async def transaction(self) -> AsyncIterator[None]: async with self.db_wrapper.writer(): yield - async def migrate_db(self) -> None: + async def insert_into_data_store_from_file( + self, + store_id: bytes32, + root_hash: Optional[bytes32], + filename: Path, + ) -> None: + internal_nodes: dict[bytes32, tuple[bytes32, bytes32]] = {} + terminal_nodes: dict[bytes32, tuple[KVId, KVId]] = {} + + with open(filename, "rb") as reader: + while True: + chunk = b"" + while len(chunk) < 4: + size_to_read = 4 - len(chunk) + cur_chunk = reader.read(size_to_read) + if cur_chunk is None or cur_chunk == b"": + if size_to_read < 4: + raise Exception("Incomplete read of length.") + break + chunk += cur_chunk + if chunk == b"": + break + + size = int.from_bytes(chunk, byteorder="big") + serialize_nodes_bytes = b"" + while len(serialize_nodes_bytes) < size: + size_to_read = size - len(serialize_nodes_bytes) + cur_chunk = reader.read(size_to_read) + if cur_chunk is None or cur_chunk == b"": + raise Exception("Incomplete read of blob.") + serialize_nodes_bytes += cur_chunk + serialized_node = SerializedNode.from_bytes(serialize_nodes_bytes) + + node_type = NodeType.TERMINAL if serialized_node.is_terminal else NodeType.INTERNAL + if node_type == NodeType.INTERNAL: + node_hash = internal_hash(bytes32(serialized_node.value1), bytes32(serialized_node.value2)) + internal_nodes[node_hash] = (bytes32(serialized_node.value1), bytes32(serialized_node.value2)) + else: + kid, vid = await self.add_key_value(serialized_node.value1, serialized_node.value2, store_id) + node_hash = leaf_hash(serialized_node.value1, serialized_node.value2) + terminal_nodes[node_hash] = (kid, vid) + + merkle_blob = MerkleBlob(blob=bytearray()) + if root_hash is not None: + await self.build_blob_from_nodes(internal_nodes, terminal_nodes, root_hash, merkle_blob) + 
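insert_into_data_store_from_file above consumes delta files as a sequence of length-prefixed records: a 4-byte big-endian size followed by one serialized node. A standalone sketch of that framing, treating each payload as opaque bytes rather than parsing SerializedNode, and reading each piece in one call rather than retrying partial reads as the real code does:

# Standalone sketch of the length-prefixed framing read by
# insert_into_data_store_from_file above: 4-byte big-endian size, then the payload.
import io
from collections.abc import Iterator
from typing import BinaryIO


def write_records(stream: BinaryIO, payloads: list[bytes]) -> None:
    for payload in payloads:
        stream.write(len(payload).to_bytes(4, byteorder="big"))
        stream.write(payload)


def read_records(stream: BinaryIO) -> Iterator[bytes]:
    while True:
        prefix = stream.read(4)
        if prefix == b"":
            return  # clean end of file
        if len(prefix) < 4:
            raise Exception("Incomplete read of length.")
        size = int.from_bytes(prefix, byteorder="big")
        payload = stream.read(size)
        if len(payload) < size:
            raise Exception("Incomplete read of blob.")
        yield payload


buffer = io.BytesIO()
write_records(buffer, [b"node-1", b"node-2"])
buffer.seek(0)
assert list(read_records(buffer)) == [b"node-1", b"node-2"]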
+ await self.insert_root_from_merkle_blob(merkle_blob, store_id, Status.COMMITTED) + await self.add_node_hashes(store_id) + + async def migrate_db(self, server_files_location: Path) -> None: async with self.db_wrapper.reader() as reader: cursor = await reader.execute("SELECT * FROM schema") - row = await cursor.fetchone() - if row is not None: + rows = await cursor.fetchall() + all_versions = {"v1.0", "v2.0"} + + for row in rows: version = row["version_id"] - if version != "v1.0": + if version not in all_versions: raise Exception("Unknown version") - log.info(f"Found DB schema version {version}. No migration needed.") - return + if version == "v2.0": + log.info(f"Found DB schema version {version}. No migration needed.") + return + + version = "v2.0" + old_tables = ["node", "root", "ancestors"] + all_stores = await self.get_store_ids() + all_roots: list[list[Root]] = [] + for store_id in all_stores: + try: + root = await self.get_tree_root(store_id=store_id) + roots = await self.get_roots_between(store_id, 1, root.generation) + all_roots.append(roots + [root]) + except Exception as e: + if "unable to find root for id, generation" in str(e): + log.error(f"Cannot find roots for {store_id}. Skipping it.") + + log.info(f"Initiating migration to version {version}. Found {len(all_roots)} stores to migrate") - version = "v1.0" - log.info(f"Initiating migration to version {version}") - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: + async with self.db_wrapper.writer() as writer: await writer.execute( f""" CREATE TABLE IF NOT EXISTS new_root( @@ -207,16 +279,294 @@ async def migrate_db(self) -> None: status INTEGER NOT NULL CHECK( {" OR ".join(f"status == {status}" for status in Status)} ), - PRIMARY KEY(tree_id, generation), - FOREIGN KEY(node_hash) REFERENCES node(hash) + PRIMARY KEY(tree_id, generation) ) """ ) - await writer.execute("INSERT INTO new_root SELECT * FROM root") - await writer.execute("DROP TABLE root") + for old_table in old_tables: + await writer.execute(f"DROP TABLE IF EXISTS {old_table}") await writer.execute("ALTER TABLE new_root RENAME TO root") await writer.execute("INSERT INTO schema (version_id) VALUES (?)", (version,)) - log.info(f"Finished migrating DB to version {version}") + log.info(f"Initialized new DB schema {version}.") + + for roots in all_roots: + assert len(roots) > 0 + store_id = roots[0].store_id + await self.create_tree(store_id=store_id, status=Status.COMMITTED) + + for root in roots: + recovery_filename: Optional[Path] = None + + for group_by_store in (True, False): + filename = get_delta_filename_path( + server_files_location, + store_id, + bytes32([0] * 32) if root.node_hash is None else root.node_hash, + root.generation, + group_by_store, + ) + + if filename.exists(): + log.info(f"Found filename {filename}. 
Recovering data from it") + recovery_filename = filename + break + + if recovery_filename is None: + log.error(f"Cannot find any recovery file for root {root}") + break + + try: + await self.insert_into_data_store_from_file(store_id, root.node_hash, recovery_filename) + except Exception as e: + log.error(f"Cannot recover data from {filename}: {e}") + break + + async def get_merkle_blob(self, root_hash: Optional[bytes32]) -> MerkleBlob: + if root_hash is None: + return MerkleBlob(blob=bytearray()) + + existing_blob = self.recent_merkle_blobs.get(root_hash) + if existing_blob is not None: + return copy.deepcopy(existing_blob) + + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( + "SELECT blob FROM merkleblob WHERE hash == :root_hash", + { + "root_hash": root_hash, + }, + ) + + row = await cursor.fetchone() + + if row is None: + raise Exception(f"Cannot find merkle blob for root hash {root_hash.hex()}") + + merkle_blob = MerkleBlob(blob=bytearray(row["blob"])) + self.recent_merkle_blobs.put(root_hash, copy.deepcopy(merkle_blob)) + return merkle_blob + + async def insert_root_from_merkle_blob( + self, + merkle_blob: MerkleBlob, + store_id: bytes32, + status: Status, + old_root: Optional[Root] = None, + ) -> Root: + if not merkle_blob.empty(): + merkle_blob.calculate_lazy_hashes() + + root_hash = merkle_blob.get_root_hash() + if old_root is not None and old_root.node_hash == root_hash: + raise ValueError("Changelist resulted in no change to tree data") + + if root_hash is not None: + async with self.db_wrapper.writer() as writer: + await writer.execute( + """ + INSERT OR REPLACE INTO merkleblob (hash, blob, store_id) + VALUES (?, ?, ?) + """, + (root_hash, merkle_blob.blob, store_id), + ) + self.recent_merkle_blobs.put(root_hash, copy.deepcopy(merkle_blob)) + + return await self._insert_root(store_id, root_hash, status) + + async def get_kvid(self, blob: bytes, store_id: bytes32) -> Optional[KVId]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( + "SELECT kv_id FROM ids WHERE blob = ? AND store_id = ?", + ( + blob, + store_id, + ), + ) + row = await cursor.fetchone() + + if row is None: + return None + + return KVId(row[0]) + + async def get_blob_from_kvid(self, kv_id: KVId, store_id: bytes32) -> Optional[bytes]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( + "SELECT blob FROM ids WHERE kv_id = ? 
AND store_id = ?", + ( + kv_id, + store_id, + ), + ) + row = await cursor.fetchone() + + if row is None: + return None + + return bytes(row[0]) + + async def get_terminal_node(self, kid: KVId, vid: KVId, store_id: bytes32) -> TerminalNode: + key = await self.get_blob_from_kvid(kid, store_id) + value = await self.get_blob_from_kvid(vid, store_id) + if key is None or value is None: + raise Exception("Cannot find the key/value pair") + + return TerminalNode(hash=leaf_hash(key, value), key=key, value=value) + + async def add_kvid(self, blob: bytes, store_id: bytes32) -> KVId: + kv_id = await self.get_kvid(blob, store_id) + if kv_id is not None: + return kv_id + + async with self.db_wrapper.writer() as writer: + await writer.execute( + "INSERT INTO ids (blob, store_id) VALUES (?, ?)", + ( + blob, + store_id, + ), + ) + + kv_id = await self.get_kvid(blob, store_id) + if kv_id is None: + raise Exception("Internal error") + return kv_id + + async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tuple[KVId, KVId]: + kid = await self.add_kvid(key, store_id) + vid = await self.add_kvid(value, store_id) + hash = leaf_hash(key, value) + async with self.db_wrapper.writer() as writer: + await writer.execute( + "INSERT OR REPLACE INTO hashes (hash, kid, vid, store_id) VALUES (?, ?, ?, ?)", + ( + hash, + kid, + vid, + store_id, + ), + ) + return (kid, vid) + + async def get_node_by_hash(self, hash: bytes32, store_id: bytes32) -> tuple[KVId, KVId]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( + "SELECT * FROM hashes WHERE hash = ? AND store_id = ?", + ( + hash, + store_id, + ), + ) + + row = await cursor.fetchone() + + if row is None: + raise Exception(f"Cannot find node by hash {hash.hex()}") + + kid = KVId(row["kid"]) + vid = KVId(row["vid"]) + return (kid, vid) + + async def get_terminal_node_by_hash(self, node_hash: bytes32, store_id: bytes32) -> TerminalNode: + kid, vid = await self.get_node_by_hash(node_hash, store_id) + return await self.get_terminal_node(kid, vid, store_id) + + async def get_first_generation(self, node_hash: bytes32, store_id: bytes32) -> Optional[int]: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute( + "SELECT generation FROM nodes WHERE hash = ? AND store_id = ?", + ( + node_hash, + store_id, + ), + ) + + row = await cursor.fetchone() + if row is None: + return None + + return int(row[0]) + + async def add_node_hash( + self, store_id: bytes32, hash: bytes32, root_hash: bytes32, generation: int, index: int + ) -> None: + async with self.db_wrapper.writer() as writer: + await writer.execute( + """ + INSERT INTO nodes(store_id, hash, root_hash, generation, idx) + VALUES (?, ?, ?, ?, ?) 
+ """, + (store_id, hash, root_hash, generation, index), + ) + + async def add_node_hashes(self, store_id: bytes32) -> None: + root = await self.get_tree_root(store_id=store_id) + if root.node_hash is None: + return + + merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) + hash_to_index = merkle_blob.get_hashes_indexes() + for hash, index in hash_to_index.items(): + existing_generation = await self.get_first_generation(hash, store_id) + if existing_generation is None: + await self.add_node_hash(store_id, hash, root.node_hash, root.generation, index) + + async def build_blob_from_nodes( + self, + internal_nodes: dict[bytes32, tuple[bytes32, bytes32]], + terminal_nodes: dict[bytes32, tuple[KVId, KVId]], + node_hash: bytes32, + merkle_blob: MerkleBlob, + ) -> TreeIndex: + if node_hash not in terminal_nodes and node_hash not in internal_nodes: + async with self.db_wrapper.reader() as reader: + cursor = await reader.execute("SELECT root_hash, idx FROM nodes WHERE hash = ?", (node_hash,)) + + row = await cursor.fetchone() + if row is None: + raise Exception(f"Unknown hash {node_hash}") + + root_hash = row["root_hash"] + index = row["idx"] + + other_merkle_blob = await self.get_merkle_blob(root_hash) + nodes = other_merkle_blob.get_nodes_with_indexes(index=index) + index_to_hash = {index: bytes32(node.hash) for index, node in nodes} + for _, node in nodes: + if isinstance(node, RawLeafMerkleNode): + terminal_nodes[bytes32(node.hash)] = (node.key, node.value) + elif isinstance(node, RawInternalMerkleNode): + internal_nodes[bytes32(node.hash)] = (index_to_hash[node.left], index_to_hash[node.right]) + + index = merkle_blob.get_new_index() + if node_hash in terminal_nodes: + kid, vid = terminal_nodes[node_hash] + merkle_blob.insert_entry_to_blob( + index, + NodeMetadata(type=NodeTypeMerkleBlob.leaf, dirty=False).pack() + + pack_raw_node(RawLeafMerkleNode(node_hash, null_parent, kid, vid)), + ) + elif node_hash in internal_nodes: + merkle_blob.insert_entry_to_blob( + index, + NodeMetadata(type=NodeTypeMerkleBlob.internal, dirty=False).pack() + + pack_raw_node( + RawInternalMerkleNode( + node_hash, + null_parent, + undefined_index, + undefined_index, + ) + ), + ) + left_hash, right_hash = internal_nodes[node_hash] + left_index = await self.build_blob_from_nodes(internal_nodes, terminal_nodes, left_hash, merkle_blob) + right_index = await self.build_blob_from_nodes(internal_nodes, terminal_nodes, right_hash, merkle_blob) + for child_index in (left_index, right_index): + merkle_blob.update_entry(index=child_index, parent=index) + merkle_blob.update_entry(index=index, left=left_index, right=right_index) + + return TreeIndex(index) async def _insert_root( self, @@ -254,173 +604,8 @@ async def _insert_root( """, new_root.to_row(), ) - - # `node_hash` is now a root, so it has no ancestor. - # Don't change the ancestor table unless the root is committed. - if node_hash is not None and status == Status.COMMITTED: - values = { - "hash": node_hash, - "tree_id": store_id, - "generation": generation, - } - await writer.execute( - """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES (:hash, NULL, :tree_id, :generation) - """, - values, - ) - return new_root - async def _insert_node( - self, - node_hash: bytes32, - node_type: NodeType, - left_hash: Optional[bytes32], - right_hash: Optional[bytes32], - key: Optional[bytes], - value: Optional[bytes], - ) -> None: - # TODO: can we get sqlite to do this check? 
- values = { - "hash": node_hash, - "node_type": node_type, - "left": left_hash, - "right": right_hash, - "key": key, - "value": value, - } - - async with self.db_wrapper.writer() as writer: - try: - await writer.execute( - """ - INSERT INTO node(hash, node_type, left, right, key, value) - VALUES(:hash, :node_type, :left, :right, :key, :value) - """, - values, - ) - except aiosqlite.IntegrityError as e: - if not e.args[0].startswith("UNIQUE constraint"): - # UNIQUE constraint failed: node.hash - raise - - async with writer.execute( - "SELECT * FROM node WHERE hash == :hash LIMIT 1", - {"hash": node_hash}, - ) as cursor: - result = await cursor.fetchone() - - if result is None: - # some ideas for causes: - # an sqlite bug - # bad queries in this function - # unexpected db constraints - raise Exception("Unable to find conflicting row") from e # pragma: no cover - - result_dict = dict(result) - if result_dict != values: - raise Exception( - f"Requested insertion of node with matching hash but other values differ: {node_hash}" - ) from None - - async def insert_node(self, node_type: NodeType, value1: bytes, value2: bytes) -> None: - if node_type == NodeType.INTERNAL: - left_hash = bytes32(value1) - right_hash = bytes32(value2) - node_hash = internal_hash(left_hash, right_hash) - await self._insert_node(node_hash, node_type, bytes32(value1), bytes32(value2), None, None) - else: - node_hash = leaf_hash(key=value1, value=value2) - await self._insert_node(node_hash, node_type, None, None, value1, value2) - - async def _insert_internal_node(self, left_hash: bytes32, right_hash: bytes32) -> bytes32: - node_hash: bytes32 = internal_hash(left_hash=left_hash, right_hash=right_hash) - - await self._insert_node( - node_hash=node_hash, - node_type=NodeType.INTERNAL, - left_hash=left_hash, - right_hash=right_hash, - key=None, - value=None, - ) - - return node_hash - - async def _insert_ancestor_table( - self, - left_hash: bytes32, - right_hash: bytes32, - store_id: bytes32, - generation: int, - ) -> None: - node_hash = internal_hash(left_hash=left_hash, right_hash=right_hash) - - async with self.db_wrapper.writer() as writer: - for hash in (left_hash, right_hash): - values = { - "hash": hash, - "ancestor": node_hash, - "tree_id": store_id, - "generation": generation, - } - try: - await writer.execute( - """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES (:hash, :ancestor, :tree_id, :generation) - """, - values, - ) - except aiosqlite.IntegrityError as e: - if not e.args[0].startswith("UNIQUE constraint"): - # UNIQUE constraint failed: ancestors.hash, ancestors.tree_id, ancestors.generation - raise - - async with writer.execute( - """ - SELECT * - FROM ancestors - WHERE hash == :hash AND generation == :generation AND tree_id == :tree_id - LIMIT 1 - """, - {"hash": hash, "generation": generation, "tree_id": store_id}, - ) as cursor: - result = await cursor.fetchone() - - if result is None: - # some ideas for causes: - # an sqlite bug - # bad queries in this function - # unexpected db constraints - raise Exception("Unable to find conflicting row") from e # pragma: no cover - - result_dict = dict(result) - if result_dict != values: - raise Exception( - "Requested insertion of ancestor, where ancestor differ, but other values are identical: " - f"{hash} {generation} {store_id}" - ) from None - - async def _insert_terminal_node(self, key: bytes, value: bytes) -> bytes32: - # forcing type hint here for: - # https://github.com/Chia-Network/clvm/pull/102 - # 
https://github.com/Chia-Network/clvm/pull/106 - node_hash: bytes32 = Program.to((key, value)).get_tree_hash() - - await self._insert_node( - node_hash=node_hash, - node_type=NodeType.TERMINAL, - left_hash=None, - right_hash=None, - key=key, - value=value, - ) - - return node_hash - async def get_pending_root(self, store_id: bytes32) -> Optional[Root]: async with self.db_wrapper.reader() as reader: cursor = await reader.execute( @@ -478,21 +663,8 @@ async def change_root_status(self, root: Root, status: Status = Status.PENDING) root.generation, ), ) - # `node_hash` is now a root, so it has no ancestor. - # Don't change the ancestor table unless the root is committed. if root.node_hash is not None and status == Status.COMMITTED: - values = { - "hash": root.node_hash, - "tree_id": root.store_id, - "generation": root.generation, - } - await writer.execute( - """ - INSERT INTO ancestors(hash, ancestor, tree_id, generation) - VALUES (:hash, NULL, :tree_id, :generation) - """, - values, - ) + await self.add_node_hashes(root.store_id) async def check(self) -> None: for check in self._checks: @@ -519,30 +691,7 @@ async def _check_roots_are_incrementing(self) -> None: if len(bad_trees) > 0: raise TreeGenerationIncrementingError(store_ids=bad_trees) - async def _check_hashes(self) -> None: - async with self.db_wrapper.reader() as reader: - cursor = await reader.execute("SELECT * FROM node") - - bad_node_hashes: list[bytes32] = [] - async for row in cursor: - node = row_to_node(row=row) - if isinstance(node, InternalNode): - expected_hash = internal_hash(left_hash=node.left_hash, right_hash=node.right_hash) - elif isinstance(node, TerminalNode): - expected_hash = Program.to((node.key, node.value)).get_tree_hash() - else: - raise Exception(f"Internal error, unknown node type: {node!r}") - - if node.hash != expected_hash: - bad_node_hashes.append(node.hash) - - if len(bad_node_hashes) > 0: - raise NodeHashError(node_hashes=bad_node_hashes) - - _checks: tuple[Callable[[DataStore], Awaitable[None]], ...] = ( - _check_roots_are_incrementing, - _check_hashes, - ) + _checks: tuple[Callable[[DataStore], Awaitable[None]], ...] 
= (_check_roots_are_incrementing,) async def create_tree(self, store_id: bytes32, status: Status = Status.PENDING) -> bool: await self._insert_root(store_id=store_id, node_hash=None, status=status) @@ -659,72 +808,19 @@ async def get_ancestors( node_hash: bytes32, store_id: bytes32, root_hash: Optional[bytes32] = None, - ) -> list[InternalNode]: - async with self.db_wrapper.reader() as reader: - if root_hash is None: - root = await self.get_tree_root(store_id=store_id) - root_hash = root.node_hash - if root_hash is None: - raise Exception(f"Root hash is unspecified for store ID: {store_id.hex()}") - cursor = await reader.execute( - """ - WITH RECURSIVE - tree_from_root_hash(hash, node_type, left, right, key, value, depth) AS ( - SELECT node.*, 0 AS depth FROM node WHERE node.hash == :root_hash - UNION ALL - SELECT node.*, tree_from_root_hash.depth + 1 AS depth FROM node, tree_from_root_hash - WHERE node.hash == tree_from_root_hash.left OR node.hash == tree_from_root_hash.right - ), - ancestors(hash, node_type, left, right, key, value, depth) AS ( - SELECT node.*, NULL AS depth FROM node - WHERE node.left == :reference_hash OR node.right == :reference_hash - UNION ALL - SELECT node.*, NULL AS depth FROM node, ancestors - WHERE node.left == ancestors.hash OR node.right == ancestors.hash - ) - SELECT * FROM tree_from_root_hash INNER JOIN ancestors - WHERE tree_from_root_hash.hash == ancestors.hash - ORDER BY tree_from_root_hash.depth DESC - """, - {"reference_hash": node_hash, "root_hash": root_hash}, - ) - - # The resulting rows must represent internal nodes. InternalNode.from_row() - # does some amount of validation in the sense that it will fail if left - # or right can't turn into a bytes32 as expected. There is room for more - # validation here if desired. 
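get_ancestors now delegates to merkle_blob.get_lineage_by_key_id instead of the recursive SQL above: the blob stores parent links, so the ancestor chain is just a walk from the reference leaf up to the root. A standalone sketch of that walk over a toy parent-pointer map (the map and ordering here are illustrative, not the real MerkleBlob layout):

# Standalone sketch of walking parent links to collect a node's ancestors,
# analogous to what merkle_blob.get_lineage_by_key_id does over the flat blob.
from typing import Optional

# node name -> parent name (None marks the root); a tiny three-level tree
parent: dict[str, Optional[str]] = {
    "root": None,
    "left": "root",
    "right": "root",
    "leaf": "left",
}


def lineage(node: str) -> list[str]:
    ancestors: list[str] = []
    current = parent[node]
    while current is not None:
        ancestors.append(current)
        current = parent[current]
    return ancestors


assert lineage("leaf") == ["left", "root"]  # nearest ancestor first, root last
assert lineage("root") == []                # the root has no ancestors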
- ancestors = [InternalNode.from_row(row=row) async for row in cursor] - - return ancestors - - async def get_ancestors_optimized( - self, - node_hash: bytes32, - store_id: bytes32, generation: Optional[int] = None, - root_hash: Optional[bytes32] = None, ) -> list[InternalNode]: async with self.db_wrapper.reader(): - nodes = [] if root_hash is None: root = await self.get_tree_root(store_id=store_id, generation=generation) root_hash = root.node_hash - if root_hash is None: - return [] - - while True: - internal_node = await self._get_one_ancestor(node_hash, store_id, generation) - if internal_node is None: - break - nodes.append(internal_node) - node_hash = internal_node.hash + raise Exception(f"Root hash is unspecified for store ID: {store_id.hex()}") - if len(nodes) > 0: - if root_hash != nodes[-1].hash: - raise RuntimeError("Ancestors list didn't produce the root as top result.") + merkle_blob = await self.get_merkle_blob(root_hash=root_hash) + reference_kid, _ = await self.get_node_by_hash(node_hash, store_id) - return nodes + return merkle_blob.get_lineage_by_key_id(reference_kid) async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes32] = None) -> list[InternalNode]: async with self.db_wrapper.reader() as reader: @@ -755,45 +851,12 @@ async def get_internal_nodes(self, store_id: bytes32, root_hash: Optional[bytes3 return internal_nodes - async def get_keys_values_cursor( - self, - reader: aiosqlite.Connection, - root_hash: Optional[bytes32], - only_keys: bool = False, - ) -> aiosqlite.Cursor: - select_clause = "SELECT hash, key" if only_keys else "SELECT *" - maybe_value = "" if only_keys else "value, " - select_node_clause = "node.hash, node.node_type, node.left, node.right, node.key" if only_keys else "node.*" - return await reader.execute( - f""" - WITH RECURSIVE - tree_from_root_hash(hash, node_type, left, right, key, {maybe_value}depth, rights) AS ( - SELECT {select_node_clause}, 0 AS depth, 0 AS rights FROM node WHERE node.hash == :root_hash - UNION ALL - SELECT - {select_node_clause}, - tree_from_root_hash.depth + 1 AS depth, - CASE - WHEN node.hash == tree_from_root_hash.right - THEN tree_from_root_hash.rights + (1 << (62 - tree_from_root_hash.depth)) - ELSE tree_from_root_hash.rights - END AS rights - FROM node, tree_from_root_hash - WHERE node.hash == tree_from_root_hash.left OR node.hash == tree_from_root_hash.right - ) - {select_clause} FROM tree_from_root_hash - WHERE node_type == :node_type - ORDER BY depth ASC, rights ASC - """, - {"root_hash": root_hash, "node_type": NodeType.TERMINAL}, - ) - async def get_keys_values( self, store_id: bytes32, root_hash: Union[bytes32, Unspecified] = unspecified, ) -> list[TerminalNode]: - async with self.db_wrapper.reader() as reader: + async with self.db_wrapper.reader(): resolved_root_hash: Optional[bytes32] if root_hash is unspecified: root = await self.get_tree_root(store_id=store_id) @@ -801,25 +864,19 @@ async def get_keys_values( else: resolved_root_hash = root_hash - cursor = await self.get_keys_values_cursor(reader, resolved_root_hash) + try: + merkle_blob = await self.get_merkle_blob(root_hash=resolved_root_hash) + except Exception as e: + if str(e).startswith("Cannot find merkle blob for root hash"): + return [] + raise + + kv_ids = merkle_blob.get_keys_values() + terminal_nodes: list[TerminalNode] = [] - async for row in cursor: - if row["depth"] > 62: - # TODO: Review the value and implementation of left-to-right order - # reporting. 
Initial use is for balanced insertion with the - # work done in the query. - - # This is limited based on the choice of 63 for the maximum left - # shift in the query. This is in turn based on the SQLite integers - # ranging in size up to signed 8 bytes, 64 bits. If we exceed this then - # we no longer guarantee the left-to-right ordering of the node - # list. While 63 allows for a lot of nodes in a balanced tree, in - # the worst case it allows only 62 terminal nodes. - raise Exception("Tree depth exceeded 62, unable to guarantee left-to-right node order.") - node = row_to_node(row=row) - if not isinstance(node, TerminalNode): - raise Exception(f"Unexpected internal node found: {node.hash.hex()}") - terminal_nodes.append(node) + for kid, vid in kv_ids.items(): + terminal_node = await self.get_terminal_node(kid, vid, store_id) + terminal_nodes.append(terminal_node) return terminal_nodes @@ -828,7 +885,7 @@ async def get_keys_values_compressed( store_id: bytes32, root_hash: Union[bytes32, Unspecified] = unspecified, ) -> KeysValuesCompressed: - async with self.db_wrapper.reader() as reader: + async with self.db_wrapper.reader(): resolved_root_hash: Optional[bytes32] if root_hash is unspecified: root = await self.get_tree_root(store_id=store_id) @@ -836,36 +893,25 @@ async def get_keys_values_compressed( else: resolved_root_hash = root_hash - cursor = await self.get_keys_values_cursor(reader, resolved_root_hash) - keys_values_hashed: dict[bytes32, bytes32] = {} - key_hash_to_length: dict[bytes32, int] = {} - leaf_hash_to_length: dict[bytes32, int] = {} - async for row in cursor: - if row["depth"] > 62: - raise Exception("Tree depth exceeded 62, unable to guarantee left-to-right node order.") - node = row_to_node(row=row) - if not isinstance(node, TerminalNode): - raise Exception(f"Unexpected internal node found: {node.hash.hex()}") - keys_values_hashed[key_hash(node.key)] = leaf_hash(node.key, node.value) - key_hash_to_length[key_hash(node.key)] = len(node.key) - leaf_hash_to_length[leaf_hash(node.key, node.value)] = len(node.key) + len(node.value) - - return KeysValuesCompressed(keys_values_hashed, key_hash_to_length, leaf_hash_to_length, resolved_root_hash) - - async def get_leaf_hashes_by_hashed_key( - self, store_id: bytes32, root_hash: Optional[bytes32] = None - ) -> dict[bytes32, bytes32]: - result: dict[bytes32, bytes32] = {} - async with self.db_wrapper.reader() as reader: - if root_hash is None: - root = await self.get_tree_root(store_id=store_id) - root_hash = root.node_hash - - cursor = await self.get_keys_values_cursor(reader, root_hash, True) - async for row in cursor: - result[key_hash(row["key"])] = bytes32(row["hash"]) + keys_values_hashed: dict[bytes32, bytes32] = {} + key_hash_to_length: dict[bytes32, int] = {} + leaf_hash_to_length: dict[bytes32, int] = {} + if resolved_root_hash is not None: + try: + merkle_blob = await self.get_merkle_blob(root_hash=resolved_root_hash) + except Exception as e: + if str(e).startswith("Cannot find merkle blob for root hash"): + return KeysValuesCompressed({}, {}, {}, resolved_root_hash) + raise + kv_ids = merkle_blob.get_keys_values() + for kid, vid in kv_ids.items(): + node = await self.get_terminal_node(kid, vid, store_id) + + keys_values_hashed[key_hash(node.key)] = leaf_hash(node.key, node.value) + key_hash_to_length[key_hash(node.key)] = len(node.key) + leaf_hash_to_length[leaf_hash(node.key, node.value)] = len(node.key) + len(node.value) - return result + return KeysValuesCompressed(keys_values_hashed, key_hash_to_length, 
leaf_hash_to_length, resolved_root_hash) async def get_keys_paginated( self, @@ -880,7 +926,7 @@ async def get_keys_paginated( keys: list[bytes] = [] for hash in pagination_data.hashes: leaf_hash = keys_values_compressed.keys_values_hashed[hash] - node = await self.get_node(leaf_hash) + node = await self.get_terminal_node_by_hash(leaf_hash, store_id) assert isinstance(node, TerminalNode) keys.append(node.key) @@ -903,7 +949,7 @@ async def get_keys_values_paginated( keys_values: list[TerminalNode] = [] for hash in pagination_data.hashes: - node = await self.get_node(hash) + node = await self.get_terminal_node_by_hash(hash, store_id) assert isinstance(node, TerminalNode) keys_values.append(node) @@ -945,7 +991,7 @@ async def get_kv_diff_paginated( kv_diff: list[DiffData] = [] for hash in pagination_data.hashes: - node = await self.get_node(hash) + node = await self.get_terminal_node_by_hash(hash, store_id) assert isinstance(node, TerminalNode) if hash in insertions: kv_diff.append(DiffData(OperationType.INSERT, node.key, node.value)) @@ -971,94 +1017,23 @@ async def get_node_type(self, node_hash: bytes32) -> NodeType: return NodeType(raw_node_type["node_type"]) - async def get_terminal_node_for_seed( - self, store_id: bytes32, seed: bytes32, root_hash: Optional[bytes32] = None - ) -> Optional[bytes32]: - path = "".join(reversed("".join(f"{b:08b}" for b in seed))) - async with self.db_wrapper.reader() as reader: - if root_hash is None: - root = await self.get_tree_root(store_id) - root_hash = root.node_hash - if root_hash is None: - return None - - async with reader.execute( - """ - WITH RECURSIVE - random_leaf(hash, node_type, left, right, depth, side) AS ( - SELECT - node.hash AS hash, - node.node_type AS node_type, - node.left AS left, - node.right AS right, - 1 AS depth, - SUBSTR(:path, 1, 1) as side - FROM node - WHERE node.hash == :root_hash - UNION ALL - SELECT - node.hash AS hash, - node.node_type AS node_type, - node.left AS left, - node.right AS right, - random_leaf.depth + 1 AS depth, - SUBSTR(:path, random_leaf.depth + 1, 1) as side - FROM node, random_leaf - WHERE ( - (random_leaf.side == "0" AND node.hash == random_leaf.left) - OR (random_leaf.side != "0" AND node.hash == random_leaf.right) - ) - ) - SELECT hash AS hash FROM random_leaf - WHERE node_type == :node_type - LIMIT 1 - """, - {"root_hash": root_hash, "node_type": NodeType.TERMINAL, "path": path}, - ) as cursor: - row = await cursor.fetchone() - if row is None: - # No cover since this is an error state that should be unreachable given the code - # above has already verified that there is a non-empty tree. 
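The paginated getters above now resolve each page entry through get_terminal_node_by_hash: the hashes table maps a leaf hash to a (kid, vid) pair, and the ids table maps those integer ids back to the key and value blobs. A standalone sketch of that two-step lookup with plain dicts standing in for the two tables (the real tables also scope every row by store_id, omitted here):

# Standalone sketch of resolving a leaf hash to its key/value through the
# hashes and ids tables added above; plain dicts stand in for the SQLite tables.
ids_table: dict[int, bytes] = {1: b"my-key", 2: b"my-value"}          # kv_id -> blob
hashes_table: dict[bytes, tuple[int, int]] = {b"leaf-hash": (1, 2)}   # hash -> (kid, vid)


def get_terminal_node_by_hash(leaf_hash: bytes) -> tuple[bytes, bytes]:
    if leaf_hash not in hashes_table:
        raise Exception(f"Cannot find node by hash {leaf_hash.hex()}")
    kid, vid = hashes_table[leaf_hash]
    key = ids_table.get(kid)
    value = ids_table.get(vid)
    if key is None or value is None:
        raise Exception("Cannot find the key/value pair")
    return key, value


assert get_terminal_node_by_hash(b"leaf-hash") == (b"my-key", b"my-value")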
- raise Exception("No terminal node found for seed") # pragma: no cover - return bytes32(row["hash"]) - - def get_side_for_seed(self, seed: bytes32) -> Side: - side_seed = bytes(seed)[0] - return Side.LEFT if side_seed < 128 else Side.RIGHT - async def autoinsert( self, key: bytes, value: bytes, store_id: bytes32, - use_optimized: bool = True, status: Status = Status.PENDING, root: Optional[Root] = None, ) -> InsertResult: - async with self.db_wrapper.writer(): - if root is None: - root = await self.get_tree_root(store_id=store_id) - - was_empty = root.node_hash is None - - if was_empty: - reference_node_hash = None - side = None - else: - seed = leaf_hash(key=key, value=value) - reference_node_hash = await self.get_terminal_node_for_seed(store_id, seed, root_hash=root.node_hash) - side = self.get_side_for_seed(seed) - - return await self.insert( - key=key, - value=value, - store_id=store_id, - reference_node_hash=reference_node_hash, - side=side, - use_optimized=use_optimized, - status=status, - root=root, - ) + return await self.insert( + key=key, + value=value, + store_id=store_id, + reference_node_hash=None, + side=None, + status=status, + root=root, + ) async def get_keys_values_dict( self, @@ -1073,265 +1048,106 @@ async def get_keys( store_id: bytes32, root_hash: Union[bytes32, Unspecified] = unspecified, ) -> list[bytes]: - async with self.db_wrapper.reader() as reader: + async with self.db_wrapper.reader(): if root_hash is unspecified: root = await self.get_tree_root(store_id=store_id) resolved_root_hash = root.node_hash else: resolved_root_hash = root_hash - cursor = await reader.execute( - """ - WITH RECURSIVE - tree_from_root_hash(hash, node_type, left, right, key) AS ( - SELECT node.hash, node.node_type, node.left, node.right, node.key - FROM node WHERE node.hash == :root_hash - UNION ALL - SELECT - node.hash, node.node_type, node.left, node.right, node.key FROM node, tree_from_root_hash - WHERE node.hash == tree_from_root_hash.left OR node.hash == tree_from_root_hash.right - ) - SELECT key FROM tree_from_root_hash WHERE node_type == :node_type - """, - {"root_hash": resolved_root_hash, "node_type": NodeType.TERMINAL}, - ) - keys: list[bytes] = [row["key"] async for row in cursor] + try: + merkle_blob = await self.get_merkle_blob(root_hash=resolved_root_hash) + except Exception as e: + if str(e).startswith("Cannot find merkle blob for root hash"): + return [] + raise + + kv_ids = merkle_blob.get_keys_values() + keys: list[bytes] = [] + for kid in kv_ids.keys(): + key = await self.get_blob_from_kvid(kid, store_id) + if key is None: + raise Exception(f"Unknown key corresponding to KVId: {kid}") + keys.append(key) return keys - async def get_ancestors_common( - self, - node_hash: bytes32, - store_id: bytes32, - root_hash: Optional[bytes32], - generation: Optional[int] = None, - use_optimized: bool = True, - ) -> list[InternalNode]: - if use_optimized: - ancestors: list[InternalNode] = await self.get_ancestors_optimized( - node_hash=node_hash, - store_id=store_id, - generation=generation, - root_hash=root_hash, - ) - else: - ancestors = await self.get_ancestors_optimized( - node_hash=node_hash, - store_id=store_id, - generation=generation, - root_hash=root_hash, - ) - ancestors_2: list[InternalNode] = await self.get_ancestors( - node_hash=node_hash, store_id=store_id, root_hash=root_hash - ) - if ancestors != ancestors_2: - raise RuntimeError("Ancestors optimized didn't produce the expected result.") - - if len(ancestors) >= 62: - raise RuntimeError("Tree exceeds max height of 
62.") - return ancestors - - async def update_ancestor_hashes_on_insert( - self, - store_id: bytes32, - left: bytes32, - right: bytes32, - traversal_node_hash: bytes32, - ancestors: list[InternalNode], - status: Status, - root: Root, - ) -> Root: - # update ancestors after inserting root, to keep table constraints. - insert_ancestors_cache: list[tuple[bytes32, bytes32, bytes32]] = [] - new_generation = root.generation + 1 - # create first new internal node - new_hash = await self._insert_internal_node(left_hash=left, right_hash=right) - insert_ancestors_cache.append((left, right, store_id)) - - # create updated replacements for the rest of the internal nodes - for ancestor in ancestors: - if not isinstance(ancestor, InternalNode): - raise Exception(f"Expected an internal node but got: {type(ancestor).__name__}") - - if ancestor.left_hash == traversal_node_hash: - left = new_hash - right = ancestor.right_hash - elif ancestor.right_hash == traversal_node_hash: - left = ancestor.left_hash - right = new_hash - - traversal_node_hash = ancestor.hash - - new_hash = await self._insert_internal_node(left_hash=left, right_hash=right) - insert_ancestors_cache.append((left, right, store_id)) - - new_root = await self._insert_root( - store_id=store_id, - node_hash=new_hash, - status=status, - generation=new_generation, - ) - - if status == Status.COMMITTED: - for left_hash, right_hash, store_id in insert_ancestors_cache: - await self._insert_ancestor_table(left_hash, right_hash, store_id, new_generation) + def get_reference_kid_side(self, merkle_blob: MerkleBlob, seed: bytes32) -> tuple[KVId, Side]: + side_seed = bytes(seed)[0] + side = Side.LEFT if side_seed < 128 else Side.RIGHT + reference_node = merkle_blob.get_random_leaf_node(seed) + kid = reference_node.key + return (kid, side) + + async def get_terminal_node_for_seed(self, seed: bytes32, store_id: bytes32) -> Optional[TerminalNode]: + root = await self.get_tree_root(store_id=store_id) + if root is None or root.node_hash is None: + return None - return new_root + merkle_blob = await self.get_merkle_blob(root.node_hash) + assert not merkle_blob.empty() + kid, _ = self.get_reference_kid_side(merkle_blob, seed) + key = await self.get_blob_from_kvid(kid, store_id) + assert key is not None + node = await self.get_node_by_key(key, store_id) + return node async def insert( self, key: bytes, value: bytes, store_id: bytes32, - reference_node_hash: Optional[bytes32], - side: Optional[Side], - use_optimized: bool = True, + reference_node_hash: Optional[bytes32] = None, + side: Optional[Side] = None, status: Status = Status.PENDING, root: Optional[Root] = None, ) -> InsertResult: async with self.db_wrapper.writer(): if root is None: root = await self.get_tree_root(store_id=store_id) + merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) - try: - await self.get_node_by_key(key=key, store_id=store_id) - raise Exception(f"Key already present: {key.hex()}") - except KeyNotFoundError: - pass + kid, vid = await self.add_key_value(key, value, store_id) + hash = leaf_hash(key, value) + reference_kid = None + if reference_node_hash is not None: + reference_kid, _ = await self.get_node_by_hash(reference_node_hash, store_id) was_empty = root.node_hash is None - if reference_node_hash is None: - if not was_empty: - raise Exception(f"Reference node hash must be specified for non-empty tree: {store_id.hex()}") - else: - reference_node_type = await self.get_node_type(node_hash=reference_node_hash) - if reference_node_type == NodeType.INTERNAL: - raise 
Exception("can not insert a new key/value on an internal node") - - # create new terminal node - new_terminal_node_hash = await self._insert_terminal_node(key=key, value=value) - - if was_empty: + if not was_empty and reference_kid is None: if side is not None: - raise Exception("Tree was empty so side must be unspecified, got: {side!r}") + raise Exception("Side specified without reference node hash") - new_root = await self._insert_root( - store_id=store_id, - node_hash=new_terminal_node_hash, - status=status, - ) - else: - if side is None: - raise Exception("Tree was not empty, side must be specified.") - if reference_node_hash is None: - raise Exception("Tree was not empty, reference node hash must be specified.") - if root.node_hash is None: - raise Exception("Internal error.") - - if side == Side.LEFT: - left = new_terminal_node_hash - right = reference_node_hash - elif side == Side.RIGHT: - left = reference_node_hash - right = new_terminal_node_hash - else: - raise Exception(f"Internal error, unknown side: {side!r}") - - ancestors = await self.get_ancestors_common( - node_hash=reference_node_hash, - store_id=store_id, - root_hash=root.node_hash, - generation=root.generation, - use_optimized=use_optimized, - ) - new_root = await self.update_ancestor_hashes_on_insert( - store_id=store_id, - left=left, - right=right, - traversal_node_hash=reference_node_hash, - ancestors=ancestors, - status=status, - root=root, - ) + seed = leaf_hash(key=key, value=value) + reference_kid, side = self.get_reference_kid_side(merkle_blob, seed) + + try: + merkle_blob.insert(kid, vid, hash, reference_kid, side) + except Exception as e: + if str(e) == "Key already present": + raise Exception(f"Key already present: {key.hex()}") + raise - return InsertResult(node_hash=new_terminal_node_hash, root=new_root) + new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status) + return InsertResult(node_hash=hash, root=new_root) async def delete( self, key: bytes, store_id: bytes32, - use_optimized: bool = True, status: Status = Status.PENDING, root: Optional[Root] = None, ) -> Optional[Root]: - root_hash = None if root is None else root.node_hash async with self.db_wrapper.writer(): - try: - node = await self.get_node_by_key(key=key, store_id=store_id) - node_hash = node.hash - assert isinstance(node, TerminalNode) - except KeyNotFoundError: - log.debug(f"Request to delete an unknown key ignored: {key.hex()}") - return root - - ancestors: list[InternalNode] = await self.get_ancestors_common( - node_hash=node_hash, - store_id=store_id, - root_hash=root_hash, - use_optimized=use_optimized, - ) - - if len(ancestors) == 0: - # the only node is being deleted - return await self._insert_root( - store_id=store_id, - node_hash=None, - status=status, - ) - - parent = ancestors[0] - other_hash = parent.other_child_hash(hash=node_hash) - - if len(ancestors) == 1: - # the parent is the root so the other side will become the new root - return await self._insert_root( - store_id=store_id, - node_hash=other_hash, - status=status, - ) - - old_child_hash = parent.hash - new_child_hash = other_hash if root is None: - new_generation = await self.get_tree_generation(store_id) + 1 - else: - new_generation = root.generation + 1 - # update ancestors after inserting root, to keep table constraints. 
- insert_ancestors_cache: list[tuple[bytes32, bytes32, bytes32]] = [] - # more parents to handle so let's traverse them - for ancestor in ancestors[1:]: - if ancestor.left_hash == old_child_hash: - left_hash = new_child_hash - right_hash = ancestor.right_hash - elif ancestor.right_hash == old_child_hash: - left_hash = ancestor.left_hash - right_hash = new_child_hash - else: - raise Exception("Internal error.") + root = await self.get_tree_root(store_id=store_id) + merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) - new_child_hash = await self._insert_internal_node(left_hash=left_hash, right_hash=right_hash) - insert_ancestors_cache.append((left_hash, right_hash, store_id)) - old_child_hash = ancestor.hash + kid = await self.get_kvid(key, store_id) + if kid is not None: + merkle_blob.delete(kid) - new_root = await self._insert_root( - store_id=store_id, - node_hash=new_child_hash, - status=status, - generation=new_generation, - ) - if status == Status.COMMITTED: - for left_hash, right_hash, store_id in insert_ancestors_cache: - await self._insert_ancestor_table(left_hash, right_hash, store_id, new_generation) + new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status) return new_root @@ -1340,151 +1156,20 @@ async def upsert( key: bytes, new_value: bytes, store_id: bytes32, - use_optimized: bool = True, status: Status = Status.PENDING, root: Optional[Root] = None, ) -> InsertResult: async with self.db_wrapper.writer(): if root is None: root = await self.get_tree_root(store_id=store_id) + merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) - try: - old_node = await self.get_node_by_key(key=key, store_id=store_id) - except KeyNotFoundError: - log.debug(f"Key not found: {key.hex()}. Doing an autoinsert instead") - return await self.autoinsert( - key=key, - value=new_value, - store_id=store_id, - use_optimized=use_optimized, - status=status, - root=root, - ) - if old_node.value == new_value: - log.debug(f"New value matches old value in upsert operation: {key.hex()}. Ignoring upsert") - return InsertResult(leaf_hash(key, new_value), root) - - # create new terminal node - new_terminal_node_hash = await self._insert_terminal_node(key=key, value=new_value) - - ancestors = await self.get_ancestors_common( - node_hash=old_node.hash, - store_id=store_id, - root_hash=root.node_hash, - generation=root.generation, - use_optimized=use_optimized, - ) - - # Store contains only the old root, replace it with a new root having the terminal node. 
- if len(ancestors) == 0: - new_root = await self._insert_root( - store_id=store_id, - node_hash=new_terminal_node_hash, - status=status, - ) - else: - parent = ancestors[0] - if parent.left_hash == old_node.hash: - left = new_terminal_node_hash - right = parent.right_hash - elif parent.right_hash == old_node.hash: - left = parent.left_hash - right = new_terminal_node_hash - else: - raise Exception("Internal error.") - - new_root = await self.update_ancestor_hashes_on_insert( - store_id=store_id, - left=left, - right=right, - traversal_node_hash=parent.hash, - ancestors=ancestors[1:], - status=status, - root=root, - ) - - return InsertResult(node_hash=new_terminal_node_hash, root=new_root) - - async def clean_node_table(self, writer: Optional[aiosqlite.Connection] = None) -> None: - query = """ - WITH RECURSIVE pending_nodes AS ( - SELECT node_hash AS hash FROM root - WHERE status IN (:pending_status, :pending_batch_status) - UNION ALL - SELECT n.left FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.left IS NOT NULL - UNION ALL - SELECT n.right FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.right IS NOT NULL - ) - DELETE FROM node - WHERE hash IN ( - SELECT n.hash FROM node n - LEFT JOIN ancestors a ON n.hash = a.hash - LEFT JOIN pending_nodes pn ON n.hash = pn.hash - WHERE a.hash IS NULL AND pn.hash IS NULL - ) - """ - params = {"pending_status": Status.PENDING.value, "pending_batch_status": Status.PENDING_BATCH.value} - if writer is None: - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await writer.execute(query, params) - else: - await writer.execute(query, params) - - async def get_nodes(self, node_hashes: list[bytes32]) -> list[Node]: - query_parameter_place_holders = ",".join("?" 
for _ in node_hashes) - async with self.db_wrapper.reader() as reader: - # TODO: handle SQLITE_MAX_VARIABLE_NUMBER - cursor = await reader.execute( - f"SELECT * FROM node WHERE hash IN ({query_parameter_place_holders})", - [*node_hashes], - ) - rows = await cursor.fetchall() - - hash_to_node = {row["hash"]: row_to_node(row=row) for row in rows} + kid, vid = await self.add_key_value(key, new_value, store_id) + hash = leaf_hash(key, new_value) + merkle_blob.upsert(kid, vid, hash) - missing_hashes = [node_hash.hex() for node_hash in node_hashes if node_hash not in hash_to_node] - if missing_hashes: - raise Exception(f"Nodes not found for hashes: {', '.join(missing_hashes)}") - - return [hash_to_node[node_hash] for node_hash in node_hashes] - - async def get_leaf_at_minimum_height( - self, root_hash: bytes32, hash_to_parent: dict[bytes32, InternalNode] - ) -> TerminalNode: - queue: list[bytes32] = [root_hash] - batch_size = min(500, SQLITE_MAX_VARIABLE_NUMBER - 10) - - while True: - assert len(queue) > 0 - nodes = await self.get_nodes(queue[:batch_size]) - queue = queue[batch_size:] - - for node in nodes: - if isinstance(node, TerminalNode): - return node - hash_to_parent[node.left_hash] = node - hash_to_parent[node.right_hash] = node - queue.append(node.left_hash) - queue.append(node.right_hash) - - async def batch_upsert( - self, - hash: bytes32, - to_update_hashes: set[bytes32], - pending_upsert_new_hashes: dict[bytes32, bytes32], - ) -> bytes32: - if hash not in to_update_hashes: - return hash - node = await self.get_node(hash) - if isinstance(node, TerminalNode): - return pending_upsert_new_hashes[hash] - new_left_hash = await self.batch_upsert(node.left_hash, to_update_hashes, pending_upsert_new_hashes) - new_right_hash = await self.batch_upsert(node.right_hash, to_update_hashes, pending_upsert_new_hashes) - return await self._insert_internal_node(new_left_hash, new_right_hash) + new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status) + return InsertResult(node_hash=hash, root=new_root) async def insert_batch( self, @@ -1494,22 +1179,19 @@ async def insert_batch( enable_batch_autoinsert: bool = True, ) -> Optional[bytes32]: async with self.transaction(): - old_root = await self.get_tree_root(store_id) + old_root = await self.get_tree_root(store_id=store_id) pending_root = await self.get_pending_root(store_id=store_id) - if pending_root is None: - latest_local_root: Optional[Root] = old_root - else: + if pending_root is not None: if pending_root.status == Status.PENDING_BATCH: # We have an unfinished batch, continue the current batch on top of it. 
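The rewritten `upsert()` above takes the same shape: record the key/value in the id tables, compute the new leaf hash, and let the blob swap the leaf in place before a single new root is stored. A sketch under the same assumptions as the previous examples; whether `MerkleBlob.upsert` also covers a brand-new key is not shown in these hunks.

```python
from chia.data_layer.data_layer_util import Status, leaf_hash
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.sized_bytes import bytes32


async def upsert_key(data_store: DataStore, store_id: bytes32, key: bytes, new_value: bytes) -> None:
    root = await data_store.get_tree_root(store_id=store_id)
    merkle_blob = await data_store.get_merkle_blob(root_hash=root.node_hash)
    kid, vid = await data_store.add_key_value(key, new_value, store_id)
    # Replace the leaf hash stored for this key id with the hash of the new value.
    merkle_blob.upsert(kid, vid, leaf_hash(key, new_value))
    await data_store.insert_root_from_merkle_blob(merkle_blob, store_id, Status.COMMITTED)
```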
if pending_root.generation != old_root.generation + 1: raise Exception("Internal error") - await self.change_root_status(pending_root, Status.COMMITTED) - await self.build_ancestor_table_for_latest_root(store_id=store_id) - latest_local_root = pending_root + old_root = pending_root + await self.clear_pending_roots(store_id) else: raise Exception("Internal error") - assert latest_local_root is not None + merkle_blob = await self.get_merkle_blob(root_hash=old_root.node_hash) key_hash_frequency: dict[bytes32, int] = {} first_action: dict[bytes32, str] = {} @@ -1523,166 +1205,58 @@ async def insert_batch( first_action[hash] = change["action"] last_action[hash] = change["action"] - pending_autoinsert_hashes: list[bytes32] = [] - pending_upsert_new_hashes: dict[bytes32, bytes32] = {} - leaf_hashes = await self.get_leaf_hashes_by_hashed_key(store_id) + batch_keys_values: list[tuple[KVId, KVId]] = [] + batch_hashes: list[bytes] = [] for change in changelist: if change["action"] == "insert": key = change["key"] value = change["value"] + reference_node_hash = change.get("reference_node_hash", None) side = change.get("side", None) + reference_kid: Optional[KVId] = None + if reference_node_hash is not None: + reference_kid, _ = await self.get_node_by_hash(reference_node_hash, store_id) + + key_hashed = key_hash(key) + kid, vid = await self.add_key_value(key, value, store_id) + if merkle_blob.key_exists(kid): + raise Exception(f"Key already present: {key.hex()}") + hash = leaf_hash(key, value) + if reference_node_hash is None and side is None: - hash = key_hash(key) - # The key is not referenced in any other operation but this autoinsert, hence the order - # of performing these should not matter. We perform all these autoinserts as a batch - # at the end, to speed up the tree processing operations. - # Additionally, if the first action is a delete, we can still perform the autoinsert at the - # end, since the order will be preserved. 
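The batching heuristic only defers an insert whose key hash occurs once in the changelist, or twice with the delete first; everything else keeps its original ordering. A small worked example of that bookkeeping, using `key_hash` from `data_layer_util` and the dict names from the hunk above.

```python
from chia.data_layer.data_layer_util import key_hash
from chia.types.blockchain_format.sized_bytes import bytes32

changelist = [
    {"action": "delete", "key": b"a"},
    {"action": "insert", "key": b"a", "value": b"1"},  # delete first, then re-insert: still batchable
    {"action": "insert", "key": b"b", "value": b"2"},  # key hash appears once: batchable
]

key_hash_frequency: dict[bytes32, int] = {}
first_action: dict[bytes32, str] = {}
for change in changelist:
    hashed = key_hash(change["key"])
    key_hash_frequency[hashed] = key_hash_frequency.get(hashed, 0) + 1
    first_action.setdefault(hashed, change["action"])

assert key_hash_frequency[key_hash(b"a")] == 2 and first_action[key_hash(b"a")] == "delete"
assert key_hash_frequency[key_hash(b"b")] == 1
```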
- if enable_batch_autoinsert: - if key_hash_frequency[hash] == 1 or ( - key_hash_frequency[hash] == 2 and first_action[hash] == "delete" + if enable_batch_autoinsert and reference_kid is None: + if key_hash_frequency[key_hashed] == 1 or ( + key_hash_frequency[key_hashed] == 2 and first_action[key_hashed] == "delete" ): - old_node = await self.maybe_get_node_from_key_hash(leaf_hashes, hash) - terminal_node_hash = await self._insert_terminal_node(key, value) - - if old_node is None: - pending_autoinsert_hashes.append(terminal_node_hash) - else: - if key_hash_frequency[hash] == 1: - raise Exception(f"Key already present: {key.hex()}") - else: - pending_upsert_new_hashes[old_node.hash] = terminal_node_hash + batch_keys_values.append((kid, vid)) + batch_hashes.append(hash) continue - insert_result = await self.autoinsert( - key, value, store_id, True, Status.COMMITTED, root=latest_local_root - ) - latest_local_root = insert_result.root - else: - if reference_node_hash is None or side is None: - raise Exception("Provide both reference_node_hash and side or neither.") - insert_result = await self.insert( - key, - value, - store_id, - reference_node_hash, - side, - True, - Status.COMMITTED, - root=latest_local_root, - ) - latest_local_root = insert_result.root + if not merkle_blob.empty(): + seed = leaf_hash(key=key, value=value) + reference_kid, side = self.get_reference_kid_side(merkle_blob, seed) + + merkle_blob.insert(kid, vid, hash, reference_kid, side) elif change["action"] == "delete": key = change["key"] - hash = key_hash(key) - if key_hash_frequency[hash] == 2 and last_action[hash] == "insert" and enable_batch_autoinsert: - continue - latest_local_root = await self.delete(key, store_id, True, Status.COMMITTED, root=latest_local_root) + deletion_kid = await self.get_kvid(key, store_id) + if deletion_kid is not None: + merkle_blob.delete(deletion_kid) elif change["action"] == "upsert": key = change["key"] new_value = change["value"] - hash = key_hash(key) - if key_hash_frequency[hash] == 1 and enable_batch_autoinsert: - terminal_node_hash = await self._insert_terminal_node(key, new_value) - old_node = await self.maybe_get_node_from_key_hash(leaf_hashes, hash) - if old_node is not None: - pending_upsert_new_hashes[old_node.hash] = terminal_node_hash - else: - pending_autoinsert_hashes.append(terminal_node_hash) - continue - insert_result = await self.upsert( - key, new_value, store_id, True, Status.COMMITTED, root=latest_local_root - ) - latest_local_root = insert_result.root + kid, vid = await self.add_key_value(key, new_value, store_id) + hash = leaf_hash(key, new_value) + merkle_blob.upsert(kid, vid, hash) else: raise Exception(f"Operation in batch is not insert or delete: {change}") - if len(pending_upsert_new_hashes) > 0: - to_update_hashes: set[bytes32] = set(pending_upsert_new_hashes.keys()) - to_update_queue: list[bytes32] = list(pending_upsert_new_hashes.keys()) - batch_size = min(500, SQLITE_MAX_VARIABLE_NUMBER - 10) - - while len(to_update_queue) > 0: - nodes = await self._get_one_ancestor_multiple_hashes(to_update_queue[:batch_size], store_id) - to_update_queue = to_update_queue[batch_size:] - for node in nodes: - if node.hash not in to_update_hashes: - to_update_hashes.add(node.hash) - to_update_queue.append(node.hash) - - assert latest_local_root is not None - assert latest_local_root.node_hash is not None - new_root_hash = await self.batch_upsert( - latest_local_root.node_hash, - to_update_hashes, - pending_upsert_new_hashes, - ) - latest_local_root = await 
self._insert_root(store_id, new_root_hash, Status.COMMITTED) - - # Start with the leaf nodes and pair them to form new nodes at the next level up, repeating this process - # in a bottom-up fashion until a single root node remains. This constructs a balanced tree from the leaves. - while len(pending_autoinsert_hashes) > 1: - new_hashes: list[bytes32] = [] - for i in range(0, len(pending_autoinsert_hashes) - 1, 2): - internal_node_hash = await self._insert_internal_node( - pending_autoinsert_hashes[i], pending_autoinsert_hashes[i + 1] - ) - new_hashes.append(internal_node_hash) - if len(pending_autoinsert_hashes) % 2 != 0: - new_hashes.append(pending_autoinsert_hashes[-1]) - - pending_autoinsert_hashes = new_hashes - - if len(pending_autoinsert_hashes): - subtree_hash = pending_autoinsert_hashes[0] - if latest_local_root is None or latest_local_root.node_hash is None: - await self._insert_root(store_id=store_id, node_hash=subtree_hash, status=Status.COMMITTED) - else: - hash_to_parent: dict[bytes32, InternalNode] = {} - min_height_leaf = await self.get_leaf_at_minimum_height(latest_local_root.node_hash, hash_to_parent) - ancestors: list[InternalNode] = [] - hash = min_height_leaf.hash - while hash in hash_to_parent: - node = hash_to_parent[hash] - ancestors.append(node) - hash = node.hash - - await self.update_ancestor_hashes_on_insert( - store_id=store_id, - left=min_height_leaf.hash, - right=subtree_hash, - traversal_node_hash=min_height_leaf.hash, - ancestors=ancestors, - status=Status.COMMITTED, - root=latest_local_root, - ) + if len(batch_keys_values) > 0: + merkle_blob.batch_insert(batch_keys_values, batch_hashes) - root = await self.get_tree_root(store_id=store_id) - if root.node_hash == old_root.node_hash: - if len(changelist) != 0: - await self.rollback_to_generation(store_id, old_root.generation) - raise ValueError("Changelist resulted in no change to tree data") - # We delete all "temporary" records stored in root and ancestor tables and store only the final result. - await self.rollback_to_generation(store_id, old_root.generation) - await self.insert_root_with_ancestor_table(store_id=store_id, node_hash=root.node_hash, status=status) - if status in (Status.PENDING, Status.PENDING_BATCH): - new_root = await self.get_pending_root(store_id=store_id) - assert new_root is not None - elif status == Status.COMMITTED: - new_root = await self.get_tree_root(store_id=store_id) - else: - raise Exception(f"No known status: {status}") - if new_root.node_hash != root.node_hash: - raise RuntimeError( - f"Tree root mismatches after batch update: Expected: {root.node_hash}. Got: {new_root.node_hash}" - ) - if new_root.generation != old_root.generation + 1: - raise RuntimeError( - "Didn't get the expected generation after batch update: " - f"Expected: {old_root.generation + 1}. 
Got: {new_root.generation}" - ) - return root.node_hash + new_root = await self.insert_root_from_merkle_blob(merkle_blob, store_id, status, old_root) + return new_root.node_hash async def _get_one_ancestor( self, @@ -1737,105 +1311,12 @@ async def _get_one_ancestor_multiple_hashes( rows = await cursor.fetchall() return [InternalNode.from_row(row=row) for row in rows] - async def build_ancestor_table_for_latest_root(self, store_id: bytes32) -> None: - async with self.db_wrapper.writer(): - root = await self.get_tree_root(store_id=store_id) - if root.node_hash is None: - return - previous_root = await self.get_tree_root( - store_id=store_id, - generation=max(root.generation - 1, 0), - ) - - if previous_root.node_hash is not None: - previous_internal_nodes: list[InternalNode] = await self.get_internal_nodes( - store_id=store_id, - root_hash=previous_root.node_hash, - ) - known_hashes: set[bytes32] = {node.hash for node in previous_internal_nodes} - else: - known_hashes = set() - internal_nodes: list[InternalNode] = await self.get_internal_nodes( - store_id=store_id, - root_hash=root.node_hash, - ) - for node in internal_nodes: - # We already have the same values in ancestor tables, if we have the same internal node. - # Don't reinsert it so we can save DB space. - if node.hash not in known_hashes: - await self._insert_ancestor_table(node.left_hash, node.right_hash, store_id, root.generation) - - async def insert_root_with_ancestor_table( - self, store_id: bytes32, node_hash: Optional[bytes32], status: Status = Status.PENDING - ) -> None: - async with self.db_wrapper.writer(): - await self._insert_root(store_id=store_id, node_hash=node_hash, status=status) - # Don't update the ancestor table for non-committed status. - if status == Status.COMMITTED: - await self.build_ancestor_table_for_latest_root(store_id=store_id) - - async def get_node_by_key_latest_generation(self, key: bytes, store_id: bytes32) -> TerminalNode: - async with self.db_wrapper.reader() as reader: - root = await self.get_tree_root(store_id=store_id) - if root.node_hash is None: - raise KeyNotFoundError(key=key) - - cursor = await reader.execute( - """ - SELECT a.hash FROM ancestors a - JOIN node n ON a.hash = n.hash - WHERE n.key = :key - AND a.tree_id = :tree_id - ORDER BY a.generation DESC - LIMIT 1 - """, - {"key": key, "tree_id": store_id}, - ) - - row = await cursor.fetchone() - if row is None: - raise KeyNotFoundError(key=key) - - node = await self.get_node(row["hash"]) - node_hash = node.hash - while True: - internal_node = await self._get_one_ancestor(node_hash, store_id) - if internal_node is None: - break - node_hash = internal_node.hash - - if node_hash != root.node_hash: - raise KeyNotFoundError(key=key) - assert isinstance(node, TerminalNode) - return node - - async def maybe_get_node_from_key_hash( - self, leaf_hashes: dict[bytes32, bytes32], hash: bytes32 - ) -> Optional[TerminalNode]: - if hash in leaf_hashes: - leaf_hash = leaf_hashes[hash] - node = await self.get_node(leaf_hash) - assert isinstance(node, TerminalNode) - return node - - return None - - async def maybe_get_node_by_key(self, key: bytes, store_id: bytes32) -> Optional[TerminalNode]: - try: - node = await self.get_node_by_key_latest_generation(key, store_id) - return node - except KeyNotFoundError: - return None - async def get_node_by_key( self, key: bytes, store_id: bytes32, root_hash: Union[bytes32, Unspecified] = unspecified, ) -> TerminalNode: - if root_hash is unspecified: - return await self.get_node_by_key_latest_generation(key, store_id) 
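A usage sketch for the rewritten `insert_batch` above, which now mutates a single in-memory blob and writes one new root (deferred autoinserts are flushed through `merkle_blob.batch_insert` first). The exact keyword names of `insert_batch` are not fully visible in these hunks and are assumed here.

```python
from chia.data_layer.data_layer_util import Status
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.sized_bytes import bytes32


async def apply_example_batch(data_store: DataStore, store_id: bytes32) -> None:
    changelist = [
        {"action": "insert", "key": b"0001", "value": b"abc"},
        {"action": "upsert", "key": b"0002", "value": b"def"},
        {"action": "delete", "key": b"0003"},
    ]
    # Inserts without a reference node/side are batched; upserts and deletes are applied in order.
    await data_store.insert_batch(store_id=store_id, changelist=changelist, status=Status.COMMITTED)
```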
- nodes = await self.get_keys_values(store_id=store_id, root_hash=root_hash) for node in nodes: @@ -1856,33 +1337,29 @@ async def get_node(self, node_hash: bytes32) -> Node: return node async def get_tree_as_nodes(self, store_id: bytes32) -> Node: - async with self.db_wrapper.reader() as reader: + async with self.db_wrapper.reader(): root = await self.get_tree_root(store_id=store_id) # TODO: consider actual proper behavior assert root.node_hash is not None - root_node = await self.get_node(node_hash=root.node_hash) - cursor = await reader.execute( - """ - WITH RECURSIVE - tree_from_root_hash(hash, node_type, left, right, key, value) AS ( - SELECT node.* FROM node WHERE node.hash == :root_hash - UNION ALL - SELECT node.* FROM node, tree_from_root_hash - WHERE node.hash == tree_from_root_hash.left OR node.hash == tree_from_root_hash.right - ) - SELECT * FROM tree_from_root_hash - """, - {"root_hash": root_node.hash}, - ) - nodes = [row_to_node(row=row) async for row in cursor] + merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) + + nodes = merkle_blob.get_nodes_with_indexes() hash_to_node: dict[bytes32, Node] = {} - for node in reversed(nodes): - if isinstance(node, InternalNode): - node = replace(node, left=hash_to_node[node.left_hash], right=hash_to_node[node.right_hash]) - hash_to_node[node.hash] = node + tree_node: Node + for _, node in reversed(nodes): + if isinstance(node, RawInternalMerkleNode): + left_hash = merkle_blob.get_hash_at_index(node.left) + right_hash = merkle_blob.get_hash_at_index(node.right) + tree_node = InternalNode.from_child_nodes( + left=hash_to_node[left_hash], right=hash_to_node[right_hash] + ) + else: + assert isinstance(node, RawLeafMerkleNode) + tree_node = await self.get_terminal_node(node.key, node.value, store_id) + hash_to_node[bytes32(node.hash)] = tree_node - root_node = hash_to_node[root_node.hash] + root_node = hash_to_node[root.node_hash] return root_node @@ -1891,66 +1368,25 @@ async def get_proof_of_inclusion_by_hash( node_hash: bytes32, store_id: bytes32, root_hash: Optional[bytes32] = None, - use_optimized: bool = False, ) -> ProofOfInclusion: - """Collect the information for a proof of inclusion of a hash in the Merkle - tree. - """ - - # Ideally this would use get_ancestors_common, but this _common function has this interesting property - # when used with use_optimized=False - it will compare both methods in this case and raise an exception. - # this is undesirable in the DL Offers flow where PENDING roots can cause the optimized code to fail. 
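`get_tree_as_nodes` relies on `get_nodes_with_indexes()` returning parents before their children (pre-order from index 0), so iterating the list in reverse guarantees both subtrees already exist when the parent is assembled. A trimmed sketch of the same pattern that builds a plain nested dict instead of `InternalNode`/`TerminalNode` objects; it assumes a non-empty blob.

```python
from typing import Any

from chia.data_layer.util.merkle_blob import MerkleBlob, RawInternalMerkleNode, RawLeafMerkleNode
from chia.types.blockchain_format.sized_bytes import bytes32


def blob_to_nested_dict(merkle_blob: MerkleBlob) -> dict[str, Any]:
    nodes = merkle_blob.get_nodes_with_indexes()  # pre-order: parents precede their children
    hash_to_node: dict[bytes32, dict[str, Any]] = {}
    for _, node in reversed(nodes):
        if isinstance(node, RawInternalMerkleNode):
            left_hash = merkle_blob.get_hash_at_index(node.left)
            right_hash = merkle_blob.get_hash_at_index(node.right)
            tree_node: dict[str, Any] = {"left": hash_to_node[left_hash], "right": hash_to_node[right_hash]}
        else:
            assert isinstance(node, RawLeafMerkleNode)
            tree_node = {"key_id": node.key, "value_id": node.value}
        hash_to_node[bytes32(node.hash)] = tree_node
    # nodes[0] is the root of the blob.
    return hash_to_node[bytes32(nodes[0][1].hash)]
```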
- if use_optimized: - ancestors = await self.get_ancestors_optimized(node_hash=node_hash, store_id=store_id, root_hash=root_hash) - else: - ancestors = await self.get_ancestors(node_hash=node_hash, store_id=store_id, root_hash=root_hash) - - layers: list[ProofOfInclusionLayer] = [] - child_hash = node_hash - for parent in ancestors: - layer = ProofOfInclusionLayer.from_internal_node(internal_node=parent, traversal_child_hash=child_hash) - layers.append(layer) - child_hash = parent.hash - - proof_of_inclusion = ProofOfInclusion(node_hash=node_hash, layers=layers) - - if len(ancestors) > 0: - expected_root = ancestors[-1].hash - else: - expected_root = node_hash - - if expected_root != proof_of_inclusion.root_hash: - raise Exception( - f"Incorrect root, expected: {expected_root.hex()}" - f"\n has: {proof_of_inclusion.root_hash.hex()}" - ) - - return proof_of_inclusion + if root_hash is None: + root = await self.get_tree_root(store_id=store_id) + root_hash = root.node_hash + merkle_blob = await self.get_merkle_blob(root_hash=root_hash) + kid, _ = await self.get_node_by_hash(node_hash, store_id) + return merkle_blob.get_proof_of_inclusion(kid) async def get_proof_of_inclusion_by_key( self, key: bytes, store_id: bytes32, ) -> ProofOfInclusion: - """Collect the information for a proof of inclusion of a key and its value in - the Merkle tree. - """ - async with self.db_wrapper.reader(): - node = await self.get_node_by_key(key=key, store_id=store_id) - return await self.get_proof_of_inclusion_by_hash(node_hash=node.hash, store_id=store_id) - - async def get_first_generation(self, node_hash: bytes32, store_id: bytes32) -> int: - async with self.db_wrapper.reader() as reader: - cursor = await reader.execute( - "SELECT MIN(generation) AS generation FROM ancestors WHERE hash == :hash AND tree_id == :tree_id", - {"hash": node_hash, "tree_id": store_id}, - ) - row = await cursor.fetchone() - if row is None: - raise RuntimeError("Hash not found in ancestor table.") - - generation = row["generation"] - return int(generation) + root = await self.get_tree_root(store_id=store_id) + merkle_blob = await self.get_merkle_blob(root_hash=root.node_hash) + kid = await self.get_kvid(key, store_id) + if kid is None: + raise Exception(f"Cannot find key: {key.hex()}") + return merkle_blob.get_proof_of_inclusion(kid) async def write_tree_to_file( self, @@ -1959,25 +1395,38 @@ async def write_tree_to_file( store_id: bytes32, deltas_only: bool, writer: BinaryIO, + merkle_blob: Optional[MerkleBlob] = None, + hash_to_index: Optional[dict[bytes32, TreeIndex]] = None, ) -> None: if node_hash == bytes32.zeros: return + if merkle_blob is None: + merkle_blob = await self.get_merkle_blob(root.node_hash) + if hash_to_index is None: + hash_to_index = merkle_blob.get_hashes_indexes() + if deltas_only: generation = await self.get_first_generation(node_hash, store_id) # Root's generation is not the first time we see this hash, so it's not a new delta. 
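A short usage sketch for the new key-based proof path above, checking the proof against the store's current root; `ProofOfInclusion.root_hash` comes from the removed hash-based code, the remaining names from this hunk.

```python
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.sized_bytes import bytes32


async def prove_key(data_store: DataStore, store_id: bytes32, key: bytes) -> None:
    proof = await data_store.get_proof_of_inclusion_by_key(key=key, store_id=store_id)
    root = await data_store.get_tree_root(store_id=store_id)
    # The proof should commit to the same root hash the store currently reports.
    assert proof.root_hash == root.node_hash
```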
if root.generation != generation: return - node = await self.get_node(node_hash) + + raw_index = hash_to_index[node_hash] + raw_node = merkle_blob.get_raw_node(raw_index) + to_write = b"" - if isinstance(node, InternalNode): - await self.write_tree_to_file(root, node.left_hash, store_id, deltas_only, writer) - await self.write_tree_to_file(root, node.right_hash, store_id, deltas_only, writer) - to_write = bytes(SerializedNode(False, bytes(node.left_hash), bytes(node.right_hash))) - elif isinstance(node, TerminalNode): + if isinstance(raw_node, RawInternalMerkleNode): + left_hash = merkle_blob.get_hash_at_index(raw_node.left) + right_hash = merkle_blob.get_hash_at_index(raw_node.right) + await self.write_tree_to_file(root, left_hash, store_id, deltas_only, writer, merkle_blob, hash_to_index) + await self.write_tree_to_file(root, right_hash, store_id, deltas_only, writer, merkle_blob, hash_to_index) + to_write = bytes(SerializedNode(False, bytes(left_hash), bytes(right_hash))) + elif isinstance(raw_node, RawLeafMerkleNode): + node = await self.get_terminal_node(raw_node.key, raw_node.value, store_id) to_write = bytes(SerializedNode(True, node.key, node.value)) else: - raise Exception(f"Node is neither InternalNode nor TerminalNode: {node}") + raise Exception(f"Node is neither InternalNode nor TerminalNode: {raw_node}") writer.write(len(to_write).to_bytes(4, byteorder="big")) writer.write(to_write) @@ -2065,94 +1514,38 @@ async def remove_subscriptions(self, store_id: bytes32, urls: list[str]) -> None }, ) - async def delete_store_data(self, store_id: bytes32) -> None: - async with self.db_wrapper.writer(foreign_key_enforcement_enabled=False) as writer: - await self.clean_node_table(writer) - cursor = await writer.execute( - """ - WITH RECURSIVE all_nodes AS ( - SELECT a.hash, n.left, n.right - FROM ancestors AS a - JOIN node AS n ON a.hash = n.hash - WHERE a.tree_id = :tree_id - ), - pending_nodes AS ( - SELECT node_hash AS hash FROM root - WHERE status IN (:pending_status, :pending_batch_status) - UNION ALL - SELECT n.left FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.left IS NOT NULL - UNION ALL - SELECT n.right FROM node n - INNER JOIN pending_nodes pn ON n.hash = pn.hash - WHERE n.right IS NOT NULL - ) - - SELECT hash, left, right - FROM all_nodes - WHERE hash NOT IN (SELECT hash FROM ancestors WHERE tree_id != :tree_id) - AND hash NOT IN (SELECT hash from pending_nodes) - """, - { - "tree_id": store_id, - "pending_status": Status.PENDING.value, - "pending_batch_status": Status.PENDING_BATCH.value, - }, - ) - to_delete: dict[bytes, tuple[bytes, bytes]] = {} - ref_counts: dict[bytes, int] = {} - async for row in cursor: - hash = row["hash"] - left = row["left"] - right = row["right"] - if hash in to_delete: - prev_left, prev_right = to_delete[hash] - assert prev_left == left - assert prev_right == right - continue - to_delete[hash] = (left, right) - if left is not None: - ref_counts[left] = ref_counts.get(left, 0) + 1 - if right is not None: - ref_counts[right] = ref_counts.get(right, 0) + 1 - - await writer.execute("DELETE FROM ancestors WHERE tree_id == ?", (store_id,)) - await writer.execute("DELETE FROM root WHERE tree_id == ?", (store_id,)) - queue = [hash for hash in to_delete if ref_counts.get(hash, 0) == 0] - while queue: - hash = queue.pop(0) - if hash not in to_delete: - continue - await writer.execute("DELETE FROM node WHERE hash == ?", (hash,)) - - left, right = to_delete[hash] - if left is not None: - ref_counts[left] -= 1 - if ref_counts[left] == 0: 
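`write_tree_to_file` keeps the existing on-disk framing: a 4-byte big-endian length followed by the `SerializedNode` bytes, one record per node. A minimal reader for that framing (the chunked equivalent now lives in `DataStore.insert_into_data_store_from_file`); this simplified version assumes ordinary file objects where a short read only happens at end of file.

```python
from typing import BinaryIO, Iterator

from chia.data_layer.data_layer_util import SerializedNode


def read_serialized_nodes(reader: BinaryIO) -> Iterator[SerializedNode]:
    while True:
        length_bytes = reader.read(4)
        if length_bytes == b"":
            return  # clean end of stream
        if len(length_bytes) < 4:
            raise Exception("Incomplete read of length.")
        size = int.from_bytes(length_bytes, byteorder="big")
        blob = reader.read(size)
        if len(blob) < size:
            raise Exception("Incomplete read of blob.")
        yield SerializedNode.from_bytes(blob)
```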
- queue.append(left) - - if right is not None: - ref_counts[right] -= 1 - if ref_counts[right] == 0: - queue.append(right) - async def unsubscribe(self, store_id: bytes32) -> None: async with self.db_wrapper.writer() as writer: await writer.execute( "DELETE FROM subscriptions WHERE tree_id == :tree_id", {"tree_id": store_id}, ) + await writer.execute( + "DELETE FROM hashes WHERE store_id == :store_id", + {"store_id": store_id}, + ) + await writer.execute( + "DELETE FROM merkleblob WHERE store_id == :store_id", + {"store_id": store_id}, + ) + await writer.execute( + "DELETE FROM ids WHERE store_id == :store_id", + {"store_id": store_id}, + ) + await writer.execute( + "DELETE FROM nodes WHERE store_id == :store_id", + {"store_id": store_id}, + ) async def rollback_to_generation(self, store_id: bytes32, target_generation: int) -> None: async with self.db_wrapper.writer() as writer: await writer.execute( - "DELETE FROM ancestors WHERE tree_id == :tree_id AND generation > :target_generation", + "DELETE FROM root WHERE tree_id == :tree_id AND generation > :target_generation", {"tree_id": store_id, "target_generation": target_generation}, ) await writer.execute( - "DELETE FROM root WHERE tree_id == :tree_id AND generation > :target_generation", - {"tree_id": store_id, "target_generation": target_generation}, + "DELETE FROM nodes WHERE store_id == :store_id AND generation > :target_generation", + {"store_id": store_id, "target_generation": target_generation}, ) async def update_server_info(self, store_id: bytes32, server_info: ServerInfo) -> None: diff --git a/chia/data_layer/download_data.py b/chia/data_layer/download_data.py index 9b210123995c..e8f315895310 100644 --- a/chia/data_layer/download_data.py +++ b/chia/data_layer/download_data.py @@ -10,49 +10,19 @@ import aiohttp from typing_extensions import Literal -from chia.data_layer.data_layer_util import NodeType, PluginRemote, Root, SerializedNode, ServerInfo, Status +from chia.data_layer.data_layer_util import ( + PluginRemote, + Root, + ServerInfo, + get_delta_filename, + get_delta_filename_path, + get_full_tree_filename, + get_full_tree_filename_path, +) from chia.data_layer.data_store import DataStore from chia.types.blockchain_format.sized_bytes import bytes32 -def get_full_tree_filename(store_id: bytes32, node_hash: bytes32, generation: int, group_by_store: bool = False) -> str: - if group_by_store: - return f"{store_id}/{node_hash}-full-{generation}-v1.0.dat" - return f"{store_id}-{node_hash}-full-{generation}-v1.0.dat" - - -def get_delta_filename(store_id: bytes32, node_hash: bytes32, generation: int, group_by_store: bool = False) -> str: - if group_by_store: - return f"{store_id}/{node_hash}-delta-{generation}-v1.0.dat" - return f"{store_id}-{node_hash}-delta-{generation}-v1.0.dat" - - -def get_full_tree_filename_path( - foldername: Path, - store_id: bytes32, - node_hash: bytes32, - generation: int, - group_by_store: bool = False, -) -> Path: - if group_by_store: - path = foldername.joinpath(f"{store_id}") - return path.joinpath(f"{node_hash}-full-{generation}-v1.0.dat") - return foldername.joinpath(f"{store_id}-{node_hash}-full-{generation}-v1.0.dat") - - -def get_delta_filename_path( - foldername: Path, - store_id: bytes32, - node_hash: bytes32, - generation: int, - group_by_store: bool = False, -) -> Path: - if group_by_store: - path = foldername.joinpath(f"{store_id}") - return path.joinpath(f"{node_hash}-delta-{generation}-v1.0.dat") - return foldername.joinpath(f"{store_id}-{node_hash}-delta-{generation}-v1.0.dat") - - def 
is_filename_valid(filename: str, group_by_store: bool = False) -> bool: if group_by_store: if filename.count("/") != 1: @@ -87,45 +57,6 @@ def is_filename_valid(filename: str, group_by_store: bool = False) -> bool: return reformatted == filename -async def insert_into_data_store_from_file( - data_store: DataStore, - store_id: bytes32, - root_hash: Optional[bytes32], - filename: Path, -) -> int: - num_inserted = 0 - with open(filename, "rb") as reader: - while True: - chunk = b"" - while len(chunk) < 4: - size_to_read = 4 - len(chunk) - cur_chunk = reader.read(size_to_read) - if cur_chunk is None or cur_chunk == b"": - if size_to_read < 4: - raise Exception("Incomplete read of length.") - break - chunk += cur_chunk - if chunk == b"": - break - - size = int.from_bytes(chunk, byteorder="big") - serialize_nodes_bytes = b"" - while len(serialize_nodes_bytes) < size: - size_to_read = size - len(serialize_nodes_bytes) - cur_chunk = reader.read(size_to_read) - if cur_chunk is None or cur_chunk == b"": - raise Exception("Incomplete read of blob.") - serialize_nodes_bytes += cur_chunk - serialized_node = SerializedNode.from_bytes(serialize_nodes_bytes) - - node_type = NodeType.TERMINAL if serialized_node.is_terminal else NodeType.INTERNAL - await data_store.insert_node(node_type, serialized_node.value1, serialized_node.value2) - num_inserted += 1 - - await data_store.insert_root_with_ancestor_table(store_id=store_id, node_hash=root_hash, status=Status.COMMITTED) - return num_inserted - - @dataclass class WriteFilesResult: result: bool @@ -288,15 +219,14 @@ async def insert_from_delta_file( existing_generation, group_files_by_store, ) - num_inserted = await insert_into_data_store_from_file( - data_store, + await data_store.insert_into_data_store_from_file( store_id, None if root_hash == bytes32.zeros else root_hash, target_filename_path, ) log.info( f"Successfully inserted hash {root_hash} from delta file. " - f"Generation: {existing_generation}. Store id: {store_id}. Nodes inserted: {num_inserted}." + f"Generation: {existing_generation}. Store id: {store_id}." ) if target_generation - existing_generation <= maximum_full_file_count - 1: @@ -386,4 +316,4 @@ async def http_download( new_percentage = f"{progress_byte / size:.0%}" if new_percentage != progress_percentage: progress_percentage = new_percentage - log.debug(f"Downloading delta file {filename}. {progress_percentage} of {size} bytes.") + log.info(f"Downloading delta file {filename}. 
{progress_percentage} of {size} bytes.") diff --git a/chia/data_layer/util/benchmark.py b/chia/data_layer/util/benchmark.py index 2808060897a3..1c3d0ce3bfaa 100644 --- a/chia/data_layer/util/benchmark.py +++ b/chia/data_layer/util/benchmark.py @@ -6,14 +6,13 @@ import tempfile import time from pathlib import Path -from typing import Optional -from chia.data_layer.data_layer_util import Side, TerminalNode, leaf_hash +from chia.data_layer.data_layer_util import Side, Status, leaf_hash from chia.data_layer.data_store import DataStore from chia.types.blockchain_format.sized_bytes import bytes32 -async def generate_datastore(num_nodes: int, slow_mode: bool) -> None: +async def generate_datastore(num_nodes: int) -> None: with tempfile.TemporaryDirectory() as temp_directory: temp_directory_path = Path(temp_directory) db_path = temp_directory_path.joinpath("dl_benchmark.sqlite") @@ -23,9 +22,8 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None: os.remove(db_path) async with DataStore.managed(database=db_path) as data_store: - store_id = bytes32(b"0" * 32) - await data_store.create_tree(store_id) + await data_store.create_tree(store_id, status=Status.COMMITTED) insert_time = 0.0 insert_count = 0 @@ -37,58 +35,40 @@ async def generate_datastore(num_nodes: int, slow_mode: bool) -> None: for i in range(num_nodes): key = i.to_bytes(4, byteorder="big") value = (2 * i).to_bytes(4, byteorder="big") - seed = leaf_hash(key=key, value=value) - reference_node_hash: Optional[bytes32] = await data_store.get_terminal_node_for_seed(store_id, seed) - side: Optional[Side] = data_store.get_side_for_seed(seed) + seed = leaf_hash(key, value) + node = await data_store.get_terminal_node_for_seed(seed, store_id) - if i == 0: - reference_node_hash = None - side = None if i % 3 == 0: t1 = time.time() - if not slow_mode: - await data_store.insert( - key=key, - value=value, - store_id=store_id, - reference_node_hash=reference_node_hash, - side=side, - ) - else: - await data_store.insert( - key=key, - value=value, - store_id=store_id, - reference_node_hash=reference_node_hash, - side=side, - use_optimized=False, - ) + await data_store.autoinsert( + key=key, + value=value, + store_id=store_id, + status=Status.COMMITTED, + ) t2 = time.time() - insert_time += t2 - t1 - insert_count += 1 + autoinsert_count += 1 elif i % 3 == 1: + assert node is not None + reference_node_hash = node.hash + side_seed = bytes(seed)[0] + side = Side.LEFT if side_seed < 128 else Side.RIGHT t1 = time.time() - if not slow_mode: - await data_store.autoinsert(key=key, value=value, store_id=store_id) - else: - await data_store.autoinsert( - key=key, - value=value, - store_id=store_id, - use_optimized=False, - ) + await data_store.insert( + key=key, + value=value, + store_id=store_id, + reference_node_hash=reference_node_hash, + side=side, + status=Status.COMMITTED, + ) t2 = time.time() - autoinsert_time += t2 - t1 - autoinsert_count += 1 + insert_time += t2 - t1 + insert_count += 1 else: t1 = time.time() - assert reference_node_hash is not None - node = await data_store.get_node(reference_node_hash) - assert isinstance(node, TerminalNode) - if not slow_mode: - await data_store.delete(key=node.key, store_id=store_id) - else: - await data_store.delete(key=node.key, store_id=store_id, use_optimized=False) + assert node is not None + await data_store.delete(key=node.key, store_id=store_id, status=Status.COMMITTED) t2 = time.time() delete_time += t2 - t1 delete_count += 1 @@ -96,13 +76,10 @@ async def generate_datastore(num_nodes: int, 
slow_mode: bool) -> None: print(f"Average insert time: {insert_time / insert_count}") print(f"Average autoinsert time: {autoinsert_time / autoinsert_count}") print(f"Average delete time: {delete_time / delete_count}") - print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}") + print(f"Total time for {num_nodes} operations: {insert_time + delete_time + autoinsert_time}") root = await data_store.get_tree_root(store_id=store_id) print(f"Root hash: {root.node_hash}") if __name__ == "__main__": - slow_mode = False - if len(sys.argv) > 2 and sys.argv[2] == "slow": - slow_mode = True - asyncio.run(generate_datastore(int(sys.argv[1]), slow_mode)) + asyncio.run(generate_datastore(int(sys.argv[1]))) diff --git a/chia/data_layer/util/merkle_blob.py b/chia/data_layer/util/merkle_blob.py index 9cad261545cd..d488259cde56 100644 --- a/chia/data_layer/util/merkle_blob.py +++ b/chia/data_layer/util/merkle_blob.py @@ -208,16 +208,16 @@ def update_entry( self.blob[data_start:end] = pack_raw_node(new_node) def get_random_leaf_node(self, seed: bytes) -> RawLeafMerkleNode: + path = "".join(reversed("".join(f"{b:08b}" for b in seed))) node = self.get_raw_node(TreeIndex(0)) - for byte in seed: - for bit in range(8): - if isinstance(node, RawLeafMerkleNode): - return node - assert isinstance(node, RawInternalMerkleNode) - if byte & (1 << bit): - node = self.get_raw_node(node.left) - else: - node = self.get_raw_node(node.right) + for bit in path: + if isinstance(node, RawLeafMerkleNode): + return node + assert isinstance(node, RawInternalMerkleNode) + if bit == "0": + node = self.get_raw_node(node.left) + else: + node = self.get_raw_node(node.right) raise Exception("Cannot find leaf from seed") @@ -239,6 +239,22 @@ def get_keys_indexes(self) -> dict[KVId, TreeIndex]: return key_to_index + def get_hashes_indexes(self) -> dict[bytes32, TreeIndex]: + if len(self.blob) == 0: + return {} + + hash_to_index: dict[bytes32, TreeIndex] = {} + queue: list[TreeIndex] = [TreeIndex(0)] + while len(queue) > 0: + node_index = queue.pop() + node = self.get_raw_node(node_index) + hash_to_index[bytes32(node.hash)] = node_index + if isinstance(node, RawInternalMerkleNode): + queue.append(node.left) + queue.append(node.right) + + return hash_to_index + def get_keys_values(self) -> dict[KVId, KVId]: if len(self.blob) == 0: return {} @@ -326,6 +342,9 @@ def insert_from_leaf(self, old_leaf_index: TreeIndex, new_index: TreeIndex, side if isinstance(new_node, RawLeafMerkleNode): self.key_to_index[new_node.key] = new_index + def key_exists(self, key: KVId) -> bool: + return key in self.key_to_index + def insert( self, key: KVId, @@ -359,7 +378,10 @@ def insert( if len(self.key_to_index) == 1: self.blob.clear() - internal_node_hash = internal_hash(bytes32(old_leaf.hash), bytes32(hash)) + if side == Side.LEFT: + internal_node_hash = internal_hash(bytes32(hash), bytes32(old_leaf.hash)) + else: + internal_node_hash = internal_hash(bytes32(old_leaf.hash), bytes32(hash)) self.blob.extend( NodeMetadata(type=NodeType.internal, dirty=False).pack() + pack_raw_node( @@ -476,6 +498,7 @@ def get_nodes_with_indexes(self, index: TreeIndex = TreeIndex(0)) -> list[tuple[ return this assert isinstance(node, RawInternalMerkleNode) + left_nodes = self.get_nodes_with_indexes(node.left) right_nodes = self.get_nodes_with_indexes(node.right) diff --git a/chia/rpc/data_layer_rpc_api.py b/chia/rpc/data_layer_rpc_api.py index f7f7ca308e32..7916008aa917 100644 --- a/chia/rpc/data_layer_rpc_api.py +++ 
b/chia/rpc/data_layer_rpc_api.py
@@ -605,7 +605,7 @@ async def get_proof(self, request: GetProofRequest) -> GetProofResponse:
         for key in request.keys:
             node = await self.service.data_store.get_node_by_key(store_id=request.store_id, key=key)
             pi = await self.service.data_store.get_proof_of_inclusion_by_hash(
-                store_id=request.store_id, node_hash=node.hash, use_optimized=True
+                store_id=request.store_id, node_hash=node.hash
             )
             proof = HashOnlyProof.from_key_value(