Skip to content

Commit

Permalink
simplify calculate roots
Browse files Browse the repository at this point in the history
  • Loading branch information
kcalvinalvin committed Aug 21, 2024
1 parent cace749 commit 56e6602
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions pytreexo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import bisect
import copy
import hashlib

Expand Down Expand Up @@ -95,9 +96,20 @@ def root_position(leaves: int, row: int, total_rows: int) -> int:
return shifted & mask


def isroot(position: int, numleaves: int, row: int, total_rows: int) -> bool:
root_present = numleaves & (1 << row) != 0
def detect_row(position: int, total_rows: int) -> int:
marker = 1 << total_rows
h = 0
while position & marker != 0:
marker >>= 1
h += 1

return h


def isroot(position: int, numleaves: int, total_rows: int) -> bool:
row = detect_row(position, total_rows)
rootpos = root_position(numleaves, row, total_rows)
root_present = numleaves & (1 << row) != 0
return root_present and rootpos == position


Expand All @@ -116,55 +128,42 @@ def next_least_list(list0, list1):


def calculate_roots(numleaves: int, dels: [bytes], proof: Proof) -> [bytes]:
total_rows = tree_rows(numleaves)

if not proof.targets:
return []

dels = dels if dels is not None else [None] * len(proof.targets)
proof.targets, dels = (list(t) for t in zip(*sorted(zip(proof.targets, dels))))

next_hashes, next_positions, calculated_roots = [], [], []
calculated_roots = []

row = 0
while row <= total_rows:
pos, cur_hash = -1, bytes
sib_present, sib_pos, sib_hash = False, -1, bytes

index = next_least_list(proof.targets, next_positions)
if index is None:
break

pos = proof.targets.pop(0) if index == 0 else next_positions.pop(0)
cur_hash = dels.pop(0) if index == 0 else next_hashes.pop(0)
posHash = {}
for i, target in enumerate(proof.targets):
if dels is None:
posHash[target] = None
else:
posHash[target] = dels[i]

while pos > row_maxpos(row, total_rows):
row += 1
sortedTargets = sorted(proof.targets)
while sortedTargets:
pos = sortedTargets.pop(0)
cur_hash = posHash[pos]
del posHash[pos]

if isroot(pos, numleaves, row, total_rows):
if isroot(pos, numleaves, tree_rows(numleaves)):
calculated_roots.append(cur_hash)
continue

index = next_least_list(proof.targets, next_positions)
if index is not None:
sib_pos = proof.targets[0] if index == 0 else next_positions[0]
if pos | 1 == sib_pos:
sib_present = True
sib_pos = proof.targets.pop(0) if index == 0 else next_positions.pop(0)
sib_hash = dels.pop(0) if index == 0 else next_hashes.pop(0)

next_hash = bytes
if sib_present:
next_hash = parent_hash(cur_hash, sib_hash)
parent_pos = parent(pos, tree_rows(numleaves))
bisect.insort(sortedTargets, parent_pos)

if sortedTargets and pos | 1 == sortedTargets[0]:
sib_pos = sortedTargets.pop(0)
posHash[parent_pos] = parent_hash(cur_hash, posHash[sib_pos])

del posHash[sib_pos]
else:
proofhash = proof.proof.pop(0)

if pos & 1 == 0:
next_hash = parent_hash(cur_hash, proofhash)
posHash[parent_pos] = parent_hash(cur_hash, proofhash)
else:
next_hash = parent_hash(proofhash, cur_hash)

next_hashes.append(next_hash)
next_positions.append(parent(pos, total_rows))
posHash[parent_pos] = parent_hash(proofhash, cur_hash)

return calculated_roots

0 comments on commit 56e6602

Please sign in to comment.