Skip to content

Commit

Permalink
Parsing TapLeaf scripts when decoding PSBT (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
wizz-wallet-dev authored Jun 13, 2024
2 parents e1a0640 + 419e97a commit 729fd84
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 72 deletions.
94 changes: 94 additions & 0 deletions electrumx/lib/psbt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import List, Tuple

from electrumx.lib.util import (
unpack_le_uint16_from,
unpack_le_uint32_from,
unpack_le_uint64_from,
)


def read_bytes(data, offset, length):
if offset + length > len(data):
raise IndexError(f"Offset out of range while reading bytes at offset {offset}")
return data[offset : offset + length], offset + length


def find_tapleaf_scripts(inputs):
tapleaf_scripts = []
for input_map in inputs:
for key, value in input_map.items():
if key[0] == 0x15: # 0x15 is the type for tapleaf scripts
tapleaf_scripts.append(value)
return tapleaf_scripts


def parse_psbt_hex_and_operations(psbt_hex: str) -> Tuple[str, List[bytes]]:
"""
Parse the PSBT into raw TX, and resolves the optional Atomicals operations from Taproot Leaf scripts.
:param psbt_hex: The PSBT text in hex format.
:return: converted TX in hex format and optional Atomicals operations.
"""
psbt_bytes = bytes.fromhex(psbt_hex)
magic = psbt_bytes[:5]
if magic != b"\x70\x73\x62\x74\xff":
raise ValueError("Invalid PSBT magic bytes")

offset = 5
global_map = {}
inputs = []
outputs = []

def read_varint(data, cursor):
v = data[cursor]
cursor += 1
if v < 0xFD:
return v, cursor
elif v == 0xFD:
return unpack_le_uint16_from(data, cursor)[0], cursor + 2
elif v == 0xFE:
return unpack_le_uint32_from(data, cursor)[0], cursor + 4
else:
return unpack_le_uint64_from(data, cursor)[0], cursor + 8

while offset < len(psbt_bytes):
key_len, offset = read_varint(psbt_bytes, offset)
if key_len == 0:
break
key = psbt_bytes[offset : offset + key_len]
offset += key_len
value_len, offset = read_varint(psbt_bytes, offset)
value = psbt_bytes[offset : offset + value_len]
offset += value_len

if key[0] == 0x00:
global_map[key] = value
elif key[0] == 0x01:
inputs.append((key, value))
elif key[0] == 0x02:
outputs.append((key, value))

unsigned_tx = global_map.get(b"\x00")
if unsigned_tx is None:
raise ValueError("No unsigned transaction found in PSBT")

def parse_map(data, o):
m = {}
while o < len(data) and data[o] != 0x00:
kl, o = read_varint(data, o)
k, o = read_bytes(data, o, kl)
vl, o = read_varint(data, o)
v, o = read_bytes(data, o, vl)
m[k] = v
return m, o + 1

input_count, offset_tx = read_varint(unsigned_tx, 4)
offset_tx += 4

for i in range(input_count):
if offset >= len(psbt_bytes):
raise IndexError(f"Offset out of range while parsing input map at index {i}")
input_map, offset = parse_map(psbt_bytes, offset)
inputs.append(input_map)

tap_leafs = find_tapleaf_scripts(inputs)
return unsigned_tx.hex(), tap_leafs
47 changes: 0 additions & 47 deletions electrumx/lib/tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,50 +1462,3 @@ def read_header(self, static_header_size):
header_end = self.cursor
self.cursor = start
return self._read_nbytes(header_end - start)


def psbt_hex_to_tx_hex(psbt_hex: str):
psbt_bytes = bytes.fromhex(psbt_hex)
magic = psbt_bytes[:5]
if magic != b"\x70\x73\x62\x74\xff":
raise ValueError("Invalid PSBT magic bytes")

offset = 5
global_map = {}
inputs = []
outputs = []

def read_varint(data, cursor):
v = data[cursor]
cursor += 1
if v < 0xFD:
return v, cursor
elif v == 0xFD:
return unpack_le_uint16_from(data, cursor)[0], cursor + 2
elif v == 0xFE:
return unpack_le_uint32_from(data, cursor)[0], cursor + 4
else:
return unpack_le_uint64_from(data, cursor)[0], cursor + 8

while offset < len(psbt_bytes):
key_len, offset = read_varint(psbt_bytes, offset)
if key_len == 0:
break
key = psbt_bytes[offset : offset + key_len]
offset += key_len
value_len, offset = read_varint(psbt_bytes, offset)
value = psbt_bytes[offset : offset + value_len]
offset += value_len

if key[0] == 0x00:
global_map[key] = value
elif key[0] == 0x01:
inputs.append((key, value))
elif key[0] == 0x02:
outputs.append((key, value))

unsigned_tx = global_map.get(b"\x00")
if unsigned_tx is None:
raise ValueError("No unsigned transaction found in PSBT")

return unsigned_tx.hex()
41 changes: 37 additions & 4 deletions electrumx/lib/util_atomicals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,12 +1247,45 @@ def is_op_return_dmitem_payment_marker_atomical_id(script):
return script[start_index + 5 + 2 + 1 : start_index + 5 + 2 + 1 + 36]


def parse_atomicals_operations_from_tap_leafs(scripts, allow_args_bytes: bool):
# All inputs are parsed but further upstream most operations will only function if placed in the 0'th input
op_name, payload, index = parse_protocols_operations_from_witness_for_input(scripts)
if not op_name:
return None
decoded_object = {}
if payload:
# Ensure that the payload is cbor encoded dictionary or empty
try:
decoded_object = loads(payload)
if not isinstance(decoded_object, dict):
return None
except Exception as e:
return None
# Also enforce that if there are meta, args, or ctx fields that they must be dicts
# This is done to ensure that these fields are always easily parseable and do not contain unexpected data
# which could cause parsing problems later.
# Ensure that they are not allowed to contain bytes like objects
if (
not is_sanitized_dict_whitelist_only(decoded_object.get("meta", {}))
or not is_sanitized_dict_whitelist_only(decoded_object.get("args", {}), allow_args_bytes)
or not is_sanitized_dict_whitelist_only(decoded_object.get("ctx", {}))
or not is_sanitized_dict_whitelist_only(decoded_object.get("init", {}), True)
):
return None
return {
"op": op_name,
"payload": decoded_object,
"input_index": index,
}
return None


# Parses and detects valid Atomicals protocol operations in a witness script
# Stops when it finds the first operation in the first input
def parse_protocols_operations_from_witness_for_input(txinwitness):
"""Detect and parse all operations across the witness input arrays from a tx"""
atomical_operation_type_map = {}
for script in txinwitness:
for i, script in enumerate(txinwitness):
n = 0
script_entry_len = len(script)
if script_entry_len < 39 or script[0] != 0x20:
Expand All @@ -1274,13 +1307,13 @@ def parse_protocols_operations_from_witness_for_input(txinwitness):
# Parse to ensure it is in the right format
operation_type, payload = parse_operation_from_script(script, n + 5)
if operation_type is not None:
return operation_type, payload
return operation_type, payload, i
break
if found_operation_definition:
break
else:
break
return None, None
return None, None, None


# Parses and detects the witness script array and detects the Atomicals operations
Expand All @@ -1291,7 +1324,7 @@ def parse_protocols_operations_from_witness_array(tx, tx_hash, allow_args_bytes)
txin_idx = 0
for txinwitness in tx.witness:
# All inputs are parsed but further upstream most operations will only function if placed in the 0'th input
op_name, payload = parse_protocols_operations_from_witness_for_input(txinwitness)
op_name, payload, _ = parse_protocols_operations_from_witness_for_input(txinwitness)
if not op_name:
continue
decoded_object = {}
Expand Down
36 changes: 21 additions & 15 deletions electrumx/server/session/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,11 @@ def validate_raw_tx_blueprint(self, raw_tx, raise_if_burned=True) -> AtomicalsVa
)

# Helper method to decode the transaction and returns formatted structure.
async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
async def transaction_decode_raw_tx_blueprint(
self,
raw_tx: bytes,
tap_leafs: Optional[List[bytes]],
) -> dict:
# Deserialize the transaction
tx, tx_hash = self.env.coin.DESERIALIZER(raw_tx, 0).read_tx_and_hash()
cache_res = self._tx_decode_cache.get(tx_hash)
Expand All @@ -848,7 +852,10 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
return cache_res

# Determine if there are any other operations at the transfer
operations_found_at_inputs = parse_protocols_operations_from_witness_array(tx, tx_hash, True)
if tap_leafs:
found_operations = parse_atomicals_operations_from_tap_leafs(tap_leafs, True)
else:
found_operations = parse_protocols_operations_from_witness_array(tx, tx_hash, True)
# Build the map of the atomicals potential spent at the tx
atomicals_spent_at_inputs: Dict[int:List] = self.bp.build_atomicals_spent_at_inputs_for_validation_only(tx)
# Build a structure of organizing into NFT and FTs
Expand All @@ -857,7 +864,7 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
blueprint_builder = AtomicalsTransferBlueprintBuilder(
self.logger,
atomicals_spent_at_inputs,
operations_found_at_inputs,
found_operations,
tx_hash,
tx,
self.bp.get_atomicals_id_mint_info,
Expand All @@ -870,15 +877,16 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
encoded_spent_at_inputs = encode_atomical_ids_hex(atomicals_spent_at_inputs)
encoded_ft_output_blueprint: Dict[str, Dict] = dict(encode_atomical_ids_hex(ft_output_blueprint))
encoded_nft_output_blueprint: Dict[str, Dict] = dict(encode_atomical_ids_hex(nft_output_blueprint))
op = operations_found_at_inputs.get("op") or "transfer"
payload = operations_found_at_inputs.get("payload")
op = found_operations.get("op") or "transfer"
burned = {
**auto_encode_bytes_items(encoded_ft_output_blueprint["fts_burned"]),
**auto_encode_bytes_items(encoded_nft_output_blueprint["nfts_burned"]),
}
ret = {
"op": [op],
"burned": {
**auto_encode_bytes_items(encoded_ft_output_blueprint["fts_burned"]),
**auto_encode_bytes_items(encoded_nft_output_blueprint["nfts_burned"]),
},
"burned": dict(sorted(burned.items())),
}
payload = found_operations.get("payload")
if payload:
ret["op_payload"] = payload
atomicals = []
Expand All @@ -905,7 +913,7 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
if not outputs.get(k3):
outputs[k3] = {}
outputs[k3][atomical_id] = item3.atomical_value
mint_info = {}
mint_info: Dict | None = None
if blueprint_builder.is_mint:
if op in ["dmt", "ft"]:
tx_out = tx.outputs[0]
Expand All @@ -915,7 +923,6 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
mint_info = {
"atomical_id": atomical_id,
"outputs": {
"atomical_id": atomical_id,
"index": 0,
"value": tx_out.value,
},
Expand All @@ -927,7 +934,6 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
mint_info = {
"atomical_id": atomical_id,
"outputs": {
"atomical_id": atomical_id,
"index": 0,
"value": tx_out.value,
},
Expand All @@ -939,7 +945,7 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
if not outputs.get(index):
outputs[index] = {}
outputs[index][atomical_id] = value
payment_info = {}
payment_info: Dict | None = None
(
payment_id,
payment_idx,
Expand All @@ -951,8 +957,8 @@ async def transaction_decode_raw_tx_blueprint(self, raw_tx: bytes) -> dict:
"payment_marker_idx": payment_idx,
}
ret["atomicals"] = [await self.atomical_id_get(atomical_id) for atomical_id in atomicals]
ret["inputs"] = inputs
ret["outputs"] = outputs
ret["inputs"] = dict(sorted(inputs.items()))
ret["outputs"] = dict(sorted(outputs.items()))
ret["payment"] = payment_info
self._tx_decode_cache[tx_hash] = ret
return ret
Expand Down
13 changes: 8 additions & 5 deletions electrumx/server/session/shared_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from electrumx.lib import util
from electrumx.lib.atomicals_blueprint_builder import AtomicalsValidationError
from electrumx.lib.psbt import parse_psbt_hex_and_operations
from electrumx.lib.script2addr import get_address_from_output_script
from electrumx.lib.tx import psbt_hex_to_tx_hex
from electrumx.lib.util_atomicals import *
from electrumx.server.daemon import DaemonError
from electrumx.server.session import ATOMICALS_INVALID_TX, BAD_REQUEST
Expand Down Expand Up @@ -964,7 +964,7 @@ async def transaction_broadcast_force(self, raw_tx: str):
return hex_hash

def transaction_validate_psbt_blueprint(self, psbt_hex: str):
raw_tx = psbt_hex_to_tx_hex(psbt_hex)
raw_tx, _ = parse_psbt_hex_and_operations(psbt_hex)
return self.transaction_validate_tx_blueprint(raw_tx)

def transaction_validate_tx_blueprint(self, raw_tx: str):
Expand All @@ -973,13 +973,16 @@ def transaction_validate_tx_blueprint(self, raw_tx: str):
return {"result": dict(result)}

async def transaction_decode_psbt(self, psbt_hex: str):
tx = psbt_hex_to_tx_hex(psbt_hex)
return await self.transaction_decode_tx(tx)
tx, tap_leafs = parse_psbt_hex_and_operations(psbt_hex)
return await self._transaction_decode(tx, tap_leafs)

async def transaction_decode_tx(self, tx: str):
return await self._transaction_decode(tx)

async def _transaction_decode(self, tx: str, tap_leafs=None):
raw_tx = bytes.fromhex(tx)
self.bump_cost(0.25 + len(raw_tx) / 5000)
result = await self.session_mgr.transaction_decode_raw_tx_blueprint(raw_tx)
result = await self.session_mgr.transaction_decode_raw_tx_blueprint(raw_tx, tap_leafs)
self.logger.debug(f"transaction_decode: {result}")
return {"result": result}

Expand Down
3 changes: 2 additions & 1 deletion tests/lib/test_tx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bitcointx.core.psbt as psbt

import electrumx.lib.psbt as psbt_lib
import electrumx.lib.tx as tx_lib

tests = [
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_psbt_parse():
"c5f5be0001012ba086010000000000225120bf3b636c6e2727c374ee2cf87d3c44515c27c5c771c9a81aa6eee938ef6a6f8e0117208e4a"
"17bed47864479d5259371382debdb949c03185c1ac6603eb4946cd7da3f30000000000000000000000"
)
decoded = tx_lib.psbt_hex_to_tx_hex(psbt_hex)
decoded, operations = psbt_lib.parse_psbt_hex_and_operations(psbt_hex)
expected_tx = (
psbt.PartiallySignedTransaction.from_base64_or_binary(
bytes.fromhex(psbt_hex),
Expand Down

0 comments on commit 729fd84

Please sign in to comment.