Skip to content

Commit

Permalink
Merge pull request #6 from kcalvinalvin/2024-08-20-simplify-calculate…
Browse files Browse the repository at this point in the history
…-roots

simplify calculate roots
  • Loading branch information
kcalvinalvin authored Aug 21, 2024
2 parents e63eb74 + 43bb12c commit 2380e01
Showing 1 changed file with 37 additions and 57 deletions.
94 changes: 37 additions & 57 deletions pytreexo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import bisect
import copy
import hashlib

Expand Down Expand Up @@ -83,88 +84,67 @@ def tree_rows(n: int) -> int:
return 0 if n == 0 else (n - 1).bit_length()


def row_maxpos(row: int, total_row: int) -> int:
mask = (2 << total_row) - 1
return ((mask << int(total_row-row)) & mask) - 1


def root_position(leaves: int, row: int, total_rows: int) -> int:
mask = (2 << total_rows) - 1
before = leaves & (mask << (row + 1))
shifted = (before >> row) | (mask << (total_rows + 1 - row))
return shifted & mask


def isroot(position: int, numleaves: int, row: int, total_rows: int) -> bool:
root_present = numleaves & (1 << row) != 0
rootpos = root_position(numleaves, row, total_rows)
return root_present and rootpos == position
def detect_row(position: int, total_rows: int) -> int:
marker = 1 << total_rows
h = 0
while position & marker != 0:
marker >>= 1
h += 1

return h

def next_least_list(list0, list1):
if list0 and list1:
if list0[0] < list1[0]:
return 0
else:
return 1
elif list0 and not list1:
return 0
elif not list0 and list1:
return 1
else:
return None

def isroot(position: int, numleaves: int, total_rows: int) -> bool:
row = detect_row(position, total_rows)
rootpos = root_position(numleaves, row, total_rows)
root_present = numleaves & (1 << row) != 0
return root_present and rootpos == position


def calculate_roots(numleaves: int, dels: [bytes], proof: Proof) -> [bytes]:
total_rows = tree_rows(numleaves)

if not proof.targets:
return []

dels = dels if dels is not None else [None] * len(proof.targets)
proof.targets, dels = (list(t) for t in zip(*sorted(zip(proof.targets, dels))))

next_hashes, next_positions, calculated_roots = [], [], []

row = 0
while row <= total_rows:
pos, cur_hash = -1, bytes
sib_present, sib_pos, sib_hash = False, -1, bytes

index = next_least_list(proof.targets, next_positions)
if index is None:
break
calculated_roots = []

pos = proof.targets.pop(0) if index == 0 else next_positions.pop(0)
cur_hash = dels.pop(0) if index == 0 else next_hashes.pop(0)
posHash = {}
for i, target in enumerate(proof.targets):
if dels is None:
posHash[target] = None
else:
posHash[target] = dels[i]

while pos > row_maxpos(row, total_rows):
row += 1
sortedTargets = sorted(proof.targets)
while sortedTargets:
pos = sortedTargets.pop(0)
cur_hash = posHash[pos]
del posHash[pos]

if isroot(pos, numleaves, row, total_rows):
if isroot(pos, numleaves, tree_rows(numleaves)):
calculated_roots.append(cur_hash)
continue

index = next_least_list(proof.targets, next_positions)
if index is not None:
sib_pos = proof.targets[0] if index == 0 else next_positions[0]
if pos | 1 == sib_pos:
sib_present = True
sib_pos = proof.targets.pop(0) if index == 0 else next_positions.pop(0)
sib_hash = dels.pop(0) if index == 0 else next_hashes.pop(0)

next_hash = bytes
if sib_present:
next_hash = parent_hash(cur_hash, sib_hash)
parent_pos = parent(pos, tree_rows(numleaves))
bisect.insort(sortedTargets, parent_pos)

if sortedTargets and pos | 1 == sortedTargets[0]:
sib_pos = sortedTargets.pop(0)
posHash[parent_pos] = parent_hash(cur_hash, posHash[sib_pos])

del posHash[sib_pos]
else:
proofhash = proof.proof.pop(0)

if pos & 1 == 0:
next_hash = parent_hash(cur_hash, proofhash)
posHash[parent_pos] = parent_hash(cur_hash, proofhash)
else:
next_hash = parent_hash(proofhash, cur_hash)

next_hashes.append(next_hash)
next_positions.append(parent(pos, total_rows))
posHash[parent_pos] = parent_hash(proofhash, cur_hash)

return calculated_roots

0 comments on commit 2380e01

Please sign in to comment.