Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parsing TapLeaf scripts when decoding PSBT #198

Merged
merged 6 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions electrumx/lib/psbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import List, Tuple

from electrumx.lib.util import (
unpack_le_uint16_from,
unpack_le_uint32_from,
unpack_le_uint64_from,
)


def read_bytes(data, offset, length):
if offset + length > len(data):
raise IndexError(f"Offset out of range while reading bytes at offset {offset}")
return data[offset : offset + length], offset + length


def find_tapleaf_scripts(inputs):
tapleaf_scripts = []
for input_map in inputs:
for key, value in input_map.items():
if key[0] == 0x15: # 0x15 is the type for tapleaf scripts
tapleaf_scripts.append(value)
return tapleaf_scripts


def parse_psbt_hex_and_operations(psbt_hex: str) -> Tuple[str, List[bytes]]:
"""
Parse the PSBT into raw TX, and resolves the optional Atomicals operations from Taproot Leaf scripts.
:param psbt_hex: The PSBT text in hex format.
:return: converted TX in hex format and optional Atomicals operations.
"""
psbt_bytes = bytes.fromhex(psbt_hex)
magic = psbt_bytes[:5]
if magic != b"\x70\x73\x62\x74\xff":
raise ValueError("Invalid PSBT magic bytes")

offset = 5
global_map = {}
inputs = []
outputs = []

def read_varint(data, cursor):
v = data[cursor]
cursor += 1
if v < 0xFD:
return v, cursor
elif v == 0xFD:
return unpack_le_uint16_from(data, cursor)[0], cursor + 2
elif v == 0xFE:
return unpack_le_uint32_from(data, cursor)[0], cursor + 4
else:
return unpack_le_uint64_from(data, cursor)[0], cursor + 8

while offset < len(psbt_bytes):
key_len, offset = read_varint(psbt_bytes, offset)
if key_len == 0:
break
key = psbt_bytes[offset : offset + key_len]
offset += key_len
value_len, offset = read_varint(psbt_bytes, offset)
value = psbt_bytes[offset : offset + value_len]
offset += value_len

if key[0] == 0x00:
global_map[key] = value
elif key[0] == 0x01:
inputs.append((key, value))
elif key[0] == 0x02:
outputs.append((key, value))

unsigned_tx = global_map.get(b"\x00")
if unsigned_tx is None:
raise ValueError("No unsigned transaction found in PSBT")

def parse_map(data, o):
m = {}
while o < len(data) and data[o] != 0x00:
kl, o = read_varint(data, o)
k, o = read_bytes(data, o, kl)
vl, o = read_varint(data, o)
v, o = read_bytes(data, o, vl)
m[k] = v
return m, o + 1

input_count, offset_tx = read_varint(unsigned_tx, 4)
offset_tx += 4

for i in range(input_count):
if offset >= len(psbt_bytes):
raise IndexError(f"Offset out of range while parsing input map at index {i}")
input_map, offset = parse_map(psbt_bytes, offset)
inputs.append(input_map)

tap_leafs = find_tapleaf_scripts(inputs)
return unsigned_tx.hex(), tap_leafs
47 changes: 0 additions & 47 deletions electrumx/lib/tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,50 +1462,3 @@ def read_header(self, static_header_size):
header_end = self.cursor
self.cursor = start
return self._read_nbytes(header_end - start)


def psbt_hex_to_tx_hex(psbt_hex: str):
psbt_bytes = bytes.fromhex(psbt_hex)
magic = psbt_bytes[:5]
if magic != b"\x70\x73\x62\x74\xff":
raise ValueError("Invalid PSBT magic bytes")

offset = 5
global_map = {}
inputs = []
outputs = []

def read_varint(data, cursor):
v = data[cursor]
cursor += 1
if v < 0xFD:
return v, cursor
elif v == 0xFD:
return unpack_le_uint16_from(data, cursor)[0], cursor + 2
elif v == 0xFE:
return unpack_le_uint32_from(data, cursor)[0], cursor + 4
else:
return unpack_le_uint64_from(data, cursor)[0], cursor + 8

while offset < len(psbt_bytes):
key_len, offset = read_varint(psbt_bytes, offset)
if key_len == 0:
break
key = psbt_bytes[offset : offset + key_len]
offset += key_len
value_len, offset = read_varint(psbt_bytes, offset)
value = psbt_bytes[offset : offset + value_len]
offset += value_len

if key[0] == 0x00:
global_map[key] = value
elif key[0] == 0x01:
inputs.append((key, value))
elif key[0] == 0x02:
outputs.append((key, value))

unsigned_tx = global_map.get(b"\x00")
if unsigned_tx is None:
raise ValueError("No unsigned transaction found in PSBT")

return unsigned_tx.hex()
41 changes: 37 additions & 4 deletions electrumx/lib/util_atomicals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,12 +1247,45 @@ def is_op_return_dmitem_payment_marker_atomical_id(script):
return script[start_index + 5 + 2 + 1 : start_index + 5 + 2 + 1 + 36]


def parse_atomicals_operations_from_tap_leafs(scripts, allow_args_bytes: bool):
# All inputs are parsed but further upstream most operations will only function if placed in the 0'th input
op_name, payload, index = parse_protocols_operations_from_witness_for_input(scripts)
if not op_name:
return None
decoded_object = {}
if payload:
# Ensure that the payload is cbor encoded dictionary or empty
try:
decoded_object = loads(payload)
if not isinstance(decoded_object, dict):
return None
except Exception as e:
return None
# Also enforce that if there are meta, args, or ctx fields that they must be dicts
# This is done to ensure that these fields are always easily parseable and do not contain unexpected data
# which could cause parsing problems later.
# Ensure that they are not allowed to contain bytes like objects
if (
not is_sanitized_dict_whitelist_only(decoded_object.get("meta", {}))
or not is_sanitized_dict_whitelist_only(decoded_object.get("args", {}), allow_args_bytes)
or not is_sanitized_dict_whitelist_only(decoded_object.get("ctx", {}))
or not is_sanitized_dict_whitelist_only(decoded_object.get("init", {}), True)
):
return None
return {
"op": op_name,
"payload": decoded_object,
"input_index": index,
}
return None


# Parses and detects valid Atomicals protocol operations in a witness script
# Stops when it finds the first operation in the first input
def parse_protocols_operations_from_witness_for_input(txinwitness):
"""Detect and parse all operations across the witness input arrays from a tx"""
atomical_operation_type_map = {}
for script in txinwitness:
for i, script in enumerate(txinwitness):
n = 0
script_entry_len = len(script)
if script_entry_len < 39 or script[0] != 0x20:
Expand All @@ -1274,13 +1307,13 @@ def parse_protocols_operations_from_witness_for_input(txinwitness):
# Parse to ensure it is in the right format
operation_type, payload = parse_operation_from_script(script, n + 5)
if operation_type is not None:
return operation_type, payload
return operation_type, payload, i
break
if found_operation_definition:
break
else:
break
return None, None
return None, None, None


# Parses and detects the witness script array and detects the Atomicals operations
Expand All @@ -1291,7 +1324,7 @@ def parse_protocols_operations_from_witness_array(tx, tx_hash, allow_args_bytes)
txin_idx = 0
for txinwitness in tx.witness:
# All inputs are parsed but further upstream most operations will only function if placed in the 0'th input
op_name, payload = parse_protocols_operations_from_witness_for_input(txinwitness)
op_name, payload, _ = parse_protocols_operations_from_witness_for_input(txinwitness)
if not op_name:
continue
decoded_object = {}
Expand Down
36 changes: 21 additions & 15 deletions electrumx/server/session/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,11 @@ def validate_raw_tx_blueprint(self, raw_tx, raise_if_burned=True) -> AtomicalsVa
)

# Helper method to decode the transaction and returns formatted structure.
async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
async def transaction_decode_raw_tx_blueprint(
self,
raw_tx: bytes,
tap_leafs: Optional[List[bytes]],
) -> dict:
# Deserialize the transaction
tx, tx_hash = self.env.coin.DESERIALIZER(raw_tx, 0).read_tx_and_hash()
cache_res = self._tx_decode_cache.get(tx_hash)
Expand All @@ -848,7 +852,10 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
return cache_res

# Determine if there are any other operations at the transfer
operations_found_at_inputs = parse_protocols_operations_from_witness_array(tx, tx_hash, True)
if tap_leafs:
found_operations = parse_atomicals_operations_from_tap_leafs(tap_leafs, True)
else:
found_operations = parse_protocols_operations_from_witness_array(tx, tx_hash, True)
# Build the map of the atomicals potential spent at the tx
atomicals_spent_at_inputs: Dict[int:List] = self.bp.build_atomicals_spent_at_inputs_for_validation_only(tx)
# Build a structure of organizing into NFT and FTs
Expand All @@ -857,7 +864,7 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
blueprint_builder = AtomicalsTransferBlueprintBuilder(
self.logger,
atomicals_spent_at_inputs,
operations_found_at_inputs,
found_operations,
tx_hash,
tx,
self.bp.get_atomicals_id_mint_info,
Expand All @@ -870,15 +877,16 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
encoded_spent_at_inputs = encode_atomical_ids_hex(atomicals_spent_at_inputs)
encoded_ft_output_blueprint: Dict[str, Dict] = dict(encode_atomical_ids_hex(ft_output_blueprint))
encoded_nft_output_blueprint: Dict[str, Dict] = dict(encode_atomical_ids_hex(nft_output_blueprint))
op = operations_found_at_inputs.get("op") or "transfer"
payload = operations_found_at_inputs.get("payload")
op = found_operations.get("op") or "transfer"
burned = {
**auto_encode_bytes_items(encoded_ft_output_blueprint["fts_burned"]),
**auto_encode_bytes_items(encoded_nft_output_blueprint["nfts_burned"]),
}
ret = {
"op": [op],
"burned": {
**auto_encode_bytes_items(encoded_ft_output_blueprint["fts_burned"]),
**auto_encode_bytes_items(encoded_nft_output_blueprint["nfts_burned"]),
},
"burned": dict(sorted(burned.items())),
}
payload = found_operations.get("payload")
if payload:
ret["op_payload"] = payload
atomicals = []
Expand All @@ -905,7 +913,7 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
if not outputs.get(k3):
outputs[k3] = {}
outputs[k3][atomical_id] = item3.atomical_value
mint_info = {}
mint_info: Dict | None = None
if blueprint_builder.is_mint:
if op in ["dmt", "ft"]:
tx_out = tx.outputs[0]
Expand All @@ -915,7 +923,6 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
mint_info = {
"atomical_id": atomical_id,
"outputs": {
"atomical_id": atomical_id,
"index": 0,
"value": tx_out.value,
},
Expand All @@ -927,7 +934,6 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
mint_info = {
"atomical_id": atomical_id,
"outputs": {
"atomical_id": atomical_id,
"index": 0,
"value": tx_out.value,
},
Expand All @@ -939,7 +945,7 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
if not outputs.get(index):
outputs[index] = {}
outputs[index][atomical_id] = value
payment_info = {}
payment_info: Dict | None = None
(
payment_id,
payment_idx,
Expand All @@ -951,8 +957,8 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
"payment_marker_idx": payment_idx,
}
ret["atomicals"] = [await self.atomical_id_get(atomical_id) for atomical_id in atomicals]
ret["inputs"] = inputs
ret["outputs"] = outputs
ret["inputs"] = dict(sorted(inputs.items()))
ret["outputs"] = dict(sorted(outputs.items()))
ret["payment"] = payment_info
self._tx_decode_cache[tx_hash] = ret
return ret
Expand Down
13 changes: 8 additions & 5 deletions electrumx/server/session/shared_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from electrumx.lib import util
from electrumx.lib.atomicals_blueprint_builder import AtomicalsValidationError
from electrumx.lib.psbt import parse_psbt_hex_and_operations
from electrumx.lib.script2addr import get_address_from_output_script
from electrumx.lib.tx import psbt_hex_to_tx_hex
from electrumx.lib.util_atomicals import *
from electrumx.server.daemon import DaemonError
from electrumx.server.session import ATOMICALS_INVALID_TX, BAD_REQUEST
Expand Down Expand Up @@ -964,7 +964,7 @@ async def transaction_broadcast_force(self, raw_tx: str):
return hex_hash

def transaction_validate_psbt_blueprint(self, psbt_hex: str):
raw_tx = psbt_hex_to_tx_hex(psbt_hex)
raw_tx, _ = parse_psbt_hex_and_operations(psbt_hex)
return self.transaction_validate_tx_blueprint(raw_tx)

def transaction_validate_tx_blueprint(self, raw_tx: str):
Expand All @@ -973,13 +973,16 @@ def transaction_validate_tx_blueprint(self, raw_tx: str):
return {"result": dict(result)}

async def transaction_decode_psbt(self, psbt_hex: str):
tx = psbt_hex_to_tx_hex(psbt_hex)
return await self.transaction_decode_tx(tx)
tx, tap_leafs = parse_psbt_hex_and_operations(psbt_hex)
return await self._transaction_decode(tx, tap_leafs)

async def transaction_decode_tx(self, tx: str):
return await self._transaction_decode(tx)

async def _transaction_decode(self, tx: str, tap_leafs=None):
raw_tx = bytes.fromhex(tx)
self.bump_cost(0.25 + len(raw_tx) / 5000)
result = await self.session_mgr.transaction_decode_raw_tx_blueprint(raw_tx)
result = await self.session_mgr.transaction_decode_raw_tx_blueprint(raw_tx, tap_leafs)
self.logger.debug(f"transaction_decode: {result}")
return {"result": result}

Expand Down
3 changes: 2 additions & 1 deletion tests/lib/test_tx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bitcointx.core.psbt as psbt

import electrumx.lib.psbt as psbt_lib
import electrumx.lib.tx as tx_lib

tests = [
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_psbt_parse():
"c5f5be0001012ba086010000000000225120bf3b636c6e2727c374ee2cf87d3c44515c27c5c771c9a81aa6eee938ef6a6f8e0117208e4a"
"17bed47864479d5259371382debdb949c03185c1ac6603eb4946cd7da3f30000000000000000000000"
)
decoded = tx_lib.psbt_hex_to_tx_hex(psbt_hex)
decoded, operations = psbt_lib.parse_psbt_hex_and_operations(psbt_hex)
expected_tx = (
psbt.PartiallySignedTransaction.from_base64_or_binary(
bytes.fromhex(psbt_hex),
Expand Down
Loading