From 9bbf1da13fa6d1b5db87afc5fbac7ebf53079aba Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 29 Dec 2023 19:19:28 +0400 Subject: [PATCH 01/76] [wip] Range-based set reconciliation --- hashsync/monoid.go | 61 +++ hashsync/monoid_tree.go | 803 ++++++++++++++++++++++++++++++++++ hashsync/monoid_tree_store.go | 105 +++++ hashsync/monoid_tree_test.go | 430 ++++++++++++++++++ hashsync/rangesync.go | 208 +++++++++ hashsync/rangesync_test.go | 459 +++++++++++++++++++ 6 files changed, 2066 insertions(+) create mode 100644 hashsync/monoid.go create mode 100644 hashsync/monoid_tree.go create mode 100644 hashsync/monoid_tree_store.go create mode 100644 hashsync/monoid_tree_test.go create mode 100644 hashsync/rangesync.go create mode 100644 hashsync/rangesync_test.go diff --git a/hashsync/monoid.go b/hashsync/monoid.go new file mode 100644 index 0000000000..0e4672fcb6 --- /dev/null +++ b/hashsync/monoid.go @@ -0,0 +1,61 @@ +package hashsync + +type Monoid interface { + Identity() any + Op(a, b any) any + Fingerprint(v any) any +} + +type CountingMonoid struct{} + +var _ Monoid = CountingMonoid{} + +func (m CountingMonoid) Identity() any { return 0 } +func (m CountingMonoid) Op(a, b any) any { return a.(int) + b.(int) } +func (m CountingMonoid) Fingerprint(v any) any { return 1 } + +type combinedMonoid struct { + m1, m2 Monoid +} + +func CombineMonoids(m1, m2 Monoid) Monoid { + return combinedMonoid{m1: m1, m2: m2} +} + +type CombinedFingerprint struct { + First any + Second any +} + +func (m combinedMonoid) Identity() any { + return CombinedFingerprint{ + First: m.m1.Identity(), + Second: m.m2.Identity(), + } +} + +func (m combinedMonoid) Op(a, b any) any { + ac := a.(CombinedFingerprint) + bc := b.(CombinedFingerprint) + return CombinedFingerprint{ + First: m.m1.Op(ac.First, bc.First), + Second: m.m2.Op(ac.Second, bc.Second), + } +} + +func (m combinedMonoid) Fingerprint(v any) any { + return CombinedFingerprint{ + First: m.m1.Fingerprint(v), + Second: 
m.m2.Fingerprint(v), + } +} + +func CombinedFirst[T any](fp any) T { + cfp := fp.(CombinedFingerprint) + return cfp.First.(T) +} + +func CombinedSecond[T any](fp any) T { + cfp := fp.(CombinedFingerprint) + return cfp.Second.(T) +} diff --git a/hashsync/monoid_tree.go b/hashsync/monoid_tree.go new file mode 100644 index 0000000000..f379fc8f40 --- /dev/null +++ b/hashsync/monoid_tree.go @@ -0,0 +1,803 @@ +// TBD: add paper ref +package hashsync + +import ( + "fmt" + "io" + "slices" + "strings" +) + +type Ordered interface { + Compare(other Ordered) int +} + +type LowerBound struct{} + +var _ Ordered = LowerBound{} + +func (vb LowerBound) Compare(x Ordered) int { return -1 } + +type UpperBound struct{} + +var _ Ordered = UpperBound{} + +func (vb UpperBound) Compare(x Ordered) int { return 1 } + +type FingerprintPredicate func(fp any) bool + +func (fpred FingerprintPredicate) Match(y any) bool { + return fpred != nil && fpred(y) +} + +type MonoidTree interface { + Fingerprint() any + Add(v Ordered) + Min() MonoidTreeNode + Max() MonoidTreeNode + RangeFingerprint(node MonoidTreeNode, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreeNode) + Dump() string +} + +func MonoidTreeFromSortedSlice[T Ordered](m Monoid, items []T) MonoidTree { + s := make([]Ordered, len(items)) + for n, item := range items { + s[n] = item + } + mt := NewMonoidTree(m).(*monoidTree) + mt.root = mt.buildFromSortedSlice(nil, s) + return mt +} + +func MonoidTreeFromSlice[T Ordered](m Monoid, items []T) MonoidTree { + sorted := make([]T, len(items)) + copy(sorted, items) + slices.SortFunc(sorted, func(a, b T) int { + return a.Compare(b) + }) + return MonoidTreeFromSortedSlice(m, items) +} + +type MonoidTreeNode interface { + Key() Ordered + Prev() MonoidTreeNode + Next() MonoidTreeNode +} + +type color uint8 + +const ( + red color = 0 + black color = 1 +) + +func (c color) flip() color { return c ^ 1 } + +func (c color) String() string { + switch c { + case red: + 
return "red" + case black: + return "black" + default: + return fmt.Sprintf("", c) + } +} + +type dir uint8 + +const ( + left dir = 0 + right dir = 1 +) + +func (d dir) flip() dir { return d ^ 1 } + +func (d dir) String() string { + switch d { + case left: + return "left" + case right: + return "right" + default: + return fmt.Sprintf("", d) + } +} + +type monoidTreeNode struct { + parent *monoidTreeNode + left *monoidTreeNode + right *monoidTreeNode + key Ordered + max Ordered + fingerprint any + color color +} + +func (mn *monoidTreeNode) red() bool { + return mn != nil && mn.color == red +} + +func (mn *monoidTreeNode) black() bool { + return mn == nil || mn.color == black +} + +func (mn *monoidTreeNode) child(dir dir) *monoidTreeNode { + if mn == nil { + return nil + } + if dir == left { + return mn.left + } + return mn.right +} + +func (mn *monoidTreeNode) setChild(dir dir, child *monoidTreeNode) { + if mn == nil { + panic("setChild for a nil node") + } + if dir == left { + mn.left = child + } else { + mn.right = child + } + if child != nil { + child.parent = mn + } +} + +func (mn *monoidTreeNode) flip() { + if mn.left == nil || mn.right == nil { + panic("can't flip color with one or more nil children") + } + mn.color = mn.color.flip() + mn.left.color = mn.left.color.flip() + mn.right.color = mn.right.color.flip() +} + +func (mn *monoidTreeNode) Key() Ordered { return mn.key } + +func (mn *monoidTreeNode) minNode() *monoidTreeNode { + if mn.left == nil { + return mn + } + return mn.left.minNode() +} + +func (mn *monoidTreeNode) maxNode() *monoidTreeNode { + if mn.right == nil { + return mn + } + return mn.right.maxNode() +} + +func (mn *monoidTreeNode) prev() *monoidTreeNode { + switch { + case mn == nil: + return nil + case mn.left != nil: + return mn.left.maxNode() + default: + p := mn.parent + for p != nil && mn == p.left { + mn = p + p = p.parent + } + return p + } +} + +func (mn *monoidTreeNode) next() *monoidTreeNode { + switch { + case mn == nil: + 
return nil + case mn.right != nil: + return mn.right.minNode() + default: + p := mn.parent + for p != nil && mn == p.right { + mn = p + p = p.parent + } + return p + } +} + +func (mn *monoidTreeNode) rmmeStr() string { + if mn == nil { + return "" + } + return fmt.Sprintf("%s", mn.key) +} + +func (mn *monoidTreeNode) Prev() MonoidTreeNode { + if prev := mn.prev(); prev != nil { + return prev + } + return nil +} + +func (mn *monoidTreeNode) Next() MonoidTreeNode { + if next := mn.next(); next != nil { + return next + } + return nil +} + +func (mn *monoidTreeNode) dump(w io.Writer, indent int) { + indentStr := strings.Repeat(" ", indent) + fmt.Fprintf(w, "%skey: %v\n", indentStr, mn.key) + fmt.Fprintf(w, "%smax: %v\n", indentStr, mn.max) + fmt.Fprintf(w, "%sfp: %v\n", indentStr, mn.fingerprint) + if mn.left != nil { + fmt.Fprintf(w, "%sleft:\n", indentStr) + mn.left.dump(w, indent+1) + if mn.left.parent != mn { + fmt.Fprintf(w, "%sERROR: bad parent on the left\n", indentStr) + } + if mn.left.key.Compare(mn.key) >= 0 { + fmt.Fprintf(w, "%sERROR: left key >= parent key\n", indentStr) + } + } + if mn.right != nil { + fmt.Fprintf(w, "%sright:\n", indentStr) + mn.right.dump(w, indent+1) + if mn.right.parent != mn { + fmt.Fprintf(w, "%sERROR: bad parent on the right\n", indentStr) + } + if mn.right.key.Compare(mn.key) <= 0 { + fmt.Fprintf(w, "%sERROR: right key <= parent key\n", indentStr) + } + } +} + +func (mn *monoidTreeNode) dumpSubtree() string { + var sb strings.Builder + mn.dump(&sb, 0) + return sb.String() +} + +type monoidTree struct { + m Monoid + root *monoidTreeNode + cachedMinNode *monoidTreeNode + cachedMaxNode *monoidTreeNode +} + +func NewMonoidTree(m Monoid) MonoidTree { + return &monoidTree{m: m} +} + +func (mt *monoidTree) Min() MonoidTreeNode { + if mt.root == nil { + return nil + } + if mt.cachedMinNode == nil { + mt.cachedMinNode = mt.root.minNode() + } + if mt.cachedMinNode == nil { + panic("BUG: no minNode in a non-empty tree") + } + return 
mt.cachedMinNode +} + +func (mt *monoidTree) Max() MonoidTreeNode { + if mt.root == nil { + return nil + } + if mt.cachedMaxNode == nil { + mt.cachedMaxNode = mt.root.maxNode() + } + if mt.cachedMaxNode == nil { + panic("BUG: no maxNode in a non-empty tree") + } + return mt.cachedMaxNode +} + +func (mt *monoidTree) Fingerprint() any { + if mt.root == nil { + return mt.m.Identity() + } + return mt.root.fingerprint +} + +func (mt *monoidTree) newNode(parent *monoidTreeNode, v Ordered) *monoidTreeNode { + return &monoidTreeNode{ + parent: parent, + key: v, + max: v, + fingerprint: mt.m.Fingerprint(v), + } +} + +func (mt *monoidTree) buildFromSortedSlice(parent *monoidTreeNode, s []Ordered) *monoidTreeNode { + switch len(s) { + case 0: + return nil + case 1: + return mt.newNode(nil, s[0]) + } + middle := len(s) / 2 + node := mt.newNode(parent, s[middle]) + node.left = mt.buildFromSortedSlice(node, s[:middle]) + node.right = mt.buildFromSortedSlice(node, s[middle+1:]) + if node.left != nil { + node.left.parent = node + node.fingerprint = mt.m.Op(node.left.fingerprint, node.fingerprint) + } + if node.right != nil { + node.right.parent = node + node.fingerprint = mt.m.Op(node.fingerprint, node.right.fingerprint) + node.max = node.right.max + } + return node +} + +func (mt *monoidTree) safeFingerprint(mn *monoidTreeNode) any { + if mn == nil { + return mt.m.Identity() + } + return mn.fingerprint +} + +func (mt *monoidTree) updateFingerprintAndMax(mn *monoidTreeNode) { + fp := mt.m.Op(mt.safeFingerprint(mn.left), mt.m.Fingerprint(mn.key)) + mn.fingerprint = mt.m.Op(fp, mt.safeFingerprint(mn.right)) + if mn.right != nil { + mn.max = mn.right.max + } else { + mn.max = mn.key + } +} + +func (mt *monoidTree) rotate(mn *monoidTreeNode, d dir) *monoidTreeNode { + // mn.verify() + + rd := d.flip() + tmp := mn.child(rd) + // fmt.Fprintf(os.Stderr, "QQQQQ: rotate %s (child at %s is %s): subtree:\n%s\n", + // d, rd, tmp.key, mn.dumpSubtree()) + mn.setChild(rd, tmp.child(d)) + 
tmp.parent = mn.parent + tmp.setChild(d, mn) + + tmp.color = mn.color + mn.color = red + + // it's important to update mn first as it may be the new right child of + // tmp, and we need to update tmp.max too + mt.updateFingerprintAndMax(mn) + mt.updateFingerprintAndMax(tmp) + + return tmp +} + +func (mt *monoidTree) doubleRotate(mn *monoidTreeNode, d dir) *monoidTreeNode { + rd := d.flip() + mn.setChild(rd, mt.rotate(mn.child(rd), rd)) + return mt.rotate(mn, d) +} + +func (mt *monoidTree) Add(v Ordered) { + mt.root = mt.insert(mt.root, v, true) + mt.root.color = black +} + +func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTreeNode { + // simplified insert implementation idea from + // https://zarif98sjs.github.io/blog/blog/redblacktree/ + if mn == nil { + mn = mt.newNode(nil, v) + if mt.cachedMinNode != nil && v.Compare(mt.cachedMinNode.key) < 0 { + mt.cachedMinNode = mn + } + if mt.cachedMaxNode != nil && v.Compare(mt.cachedMaxNode.key) > 0 { + mt.cachedMaxNode = mn + } + return mn + } + c := v.Compare(mn.key) + if c == 0 { + return mn + } + d := left + if c > 0 { + d = right + } + oldChild := mn.child(d) + newChild := mt.insert(oldChild, v, rb) + mn.setChild(d, newChild) + updateFP := true + if rb { + // non-red-black insert is used for testing + mn, updateFP = mt.insertFixup(mn, d, oldChild != newChild) + } + if updateFP { + mt.updateFingerprintAndMax(mn) + } + return mn +} + +func (mt *monoidTree) insertFixup(mn *monoidTreeNode, d dir, updateFP bool) (*monoidTreeNode, bool) { + child := mn.child(d) + rd := d.flip() + switch { + case child.black(): + return mn, true + case mn.child(rd).red(): + updateFP = true + // both children of mn are red => any child has 2 reds in a row + // (LL LR RR RL) => flip colors + if child.child(d).red() || child.child(rd).red() { + mn.flip() + } + case child.child(d).red(): + // another child of mn is black + // any child has 2 reds in a row (LL RR) => rotate + // rotate will update fingerprint of mn and 
the node + // that replaces it + mn = mt.rotate(mn, rd) + case child.child(rd).red(): + // another child of mn is black + // any child has 2 reds in a row (LR RL) => align first, then rotate + // doubleRotate will update fingerprint of mn and the node + // that replaces it + mn = mt.doubleRotate(mn, rd) + default: + updateFP = true + } + return mn, updateFP +} + +func (mt *monoidTree) findGTENode(mn *monoidTreeNode, x Ordered) *monoidTreeNode { + switch { + case mn == nil: + return nil + case x.Compare(mn.key) == 0: + // Exact match + return mn + case x.Compare(mn.max) > 0: + // All of this subtree is below v, maybe we can have + // some luck with the parent node + return mt.findGTENode(mn.parent, x) + case x.Compare(mn.key) >= 0: + // We're still below x (or at x, but allowEqual is + // false), but given that we checked Max and saw that + // this subtree has some keys that are greater than + // or equal to x, we can find them on the right + if mn.right == nil { + // mn.Max lied to us + panic("BUG: MonoidTreeNode: x > mn.Max but no right branch") + } + // Avoid endless recursion in case of a bug + if x.Compare(mn.right.max) > 0 { + panic("BUG: MonoidTreeNode: inconsistent Max on the right branch") + } + return mt.findGTENode(mn.right, x) + case mn.left == nil || x.Compare(mn.left.max) > 0: + // The current node's key is greater than x and the + // left branch is either empty or fully below x, so + // the current node is what we were looking for + return mn + default: + // Some keys on the left branch are greater or equal + // than x accordingto mn.Left.Max + r := mt.findGTENode(mn.left, x) + if r == nil { + panic("BUG: MonoidTreeNode: inconsistent Max on the left branch") + } + return r + } +} + +func (mt *monoidTree) invRangeFingerprint(mn *monoidTreeNode, x, y Ordered, stop FingerprintPredicate) (any, *monoidTreeNode) { + // QQQQQ: rename: rollover range + next := mn + minNode := mn.minNode() + + var acc any + var stopped bool + rightStartNode := 
mt.findGTENode(mn, x) + if rightStartNode != nil { + acc, next, stopped = mt.aggregateUntil(rightStartNode, acc, x, UpperBound{}, stop) + if stopped { + return acc, next + } + } else { + acc = mt.m.Identity() + } + + if y.Compare(minNode.key) > 0 { + acc, next, _ = mt.aggregateUntil(minNode, acc, LowerBound{}, y, stop) + } + + return acc, next +} + +func (mt *monoidTree) rangeFingerprint(node MonoidTreeNode, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode *monoidTreeNode) { + if mt.root == nil { + return mt.m.Identity(), nil, nil + } + if node == nil { + node = mt.root + } + + mn := node.(*monoidTreeNode) + minNode := mt.root.minNode() + acc := mt.m.Identity() + startNode = mt.findGTENode(mn, start) + switch { + case start.Compare(end) >= 0: + // rollover range, which includes the case start == end + // this includes 2 subranges: + // [start, max_element] and [min_element, end) + var stopped bool + if node != nil { + acc, endNode, stopped = mt.aggregateUntil(startNode, acc, start, UpperBound{}, stop) + } + + if !stopped && end.Compare(minNode.key) > 0 { + acc, endNode, _ = mt.aggregateUntil(minNode, acc, LowerBound{}, end, stop) + } + case node != nil: + // normal range, that is, start < end + acc, endNode, _ = mt.aggregateUntil(startNode, mt.m.Identity(), start, end, stop) + } + + if startNode == nil { + startNode = minNode + } + if endNode == nil { + endNode = minNode + } + + return acc, startNode, endNode +} + +func (mt *monoidTree) RangeFingerprint(node MonoidTreeNode, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreeNode) { + fp, stn, endn := mt.rangeFingerprint(node, start, end, stop) + switch { + case stn == nil && endn == nil: + // avoid wrapping nil in MonoidTreeNode interface + return fp, nil, nil + case stn == nil || endn == nil: + panic("BUG: can't have nil node just on one end") + default: + return fp, stn, endn + } +} + +func (mt *monoidTree) aggregateUntil(mn *monoidTreeNode, acc any, 
start, end Ordered, stop FingerprintPredicate) (fp any, node *monoidTreeNode, stopped bool) { + acc, node, stopped = mt.aggregateUp(mn, acc, start, end, stop) + if node == nil || end.Compare(node.key) <= 0 || stopped { + return acc, node, stopped + } + + f := mt.m.Op(acc, mt.m.Fingerprint(node.key)) + if stop.Match(f) { + return acc, node, true + } + return mt.aggregateDown(node.right, f, end, stop) +} + +// aggregateUp ascends from the left (lower) end of the range towards the LCA +// (lowest common ancestor) of nodes within the range [start,end). Instead of +// descending from the root node, the LCA is determined by the way of checking +// whether the stored max subtree key is below or at the end or not, saving +// some extra tree traversal when processing the ascending ranges. +// On the way up, if the current node is within the range, we include the right +// subtree in the aggregation using its saved fingerprint, as it is guaranteed +// to lie with the range. When we happen to go up from the right branch, we can +// only reach a predecessor node that lies below the start, and in this case we +// don't include the right subtree in the aggregation to avoid aggregating the +// same subset of nodes twice. +// If stop function is passed, we find the node on which it returns true +// for the fingerprint accumulated between start and that node, if the target +// node is somewhere to the left from the LCA. 
+func (mt *monoidTree) aggregateUp(mn *monoidTreeNode, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, node *monoidTreeNode, stopped bool) { + switch { + case mn == nil: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: null node\n") + return acc, nil, false + case stop.Match(acc): + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop: node %v acc %v\n", mn.key, acc) + return acc, mn.prev(), true + case end.Compare(mn.max) <= 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: LCA: node %v acc %v\n", mn.key, acc) + // This node is a the LCA, the starting point for AggregateDown + return acc, mn, false + case start.Compare(mn.key) <= 0: + // This node is within the target range + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: in-range node %v acc %v\n", mn.key, acc) + f := mt.m.Op(acc, mt.m.Fingerprint(mn.key)) + if stop.Match(f) { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop at the own node %v acc %v\n", mn.key, acc) + return acc, mn, true + } + f1 := mt.m.Op(f, mt.safeFingerprint(mn.right)) + if stop.Match(f1) { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree matches node %v acc %v f1 %v\n", mn.key, acc, f1) + // The target node is somewhere in the right subtree + if mn.right == nil { + panic("BUG: nil right child with non-identity fingerprint") + } + acc, node := mt.boundedAggregate(mn.right, f, stop) + if node == nil { + panic("BUG: aggregateUp: bad subtree fingerprint on the right branch") + } + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree: node %v acc %v\n", node.key, acc) + return acc, node, true + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: no right subtree match: node %v acc %v f1 %v\n", mn.key, acc, f1) + acc = f1 + } + } + if mn.parent == nil { + // No need for AggregateDown as we've covered the entire + // [start, end) range + return acc, nil, false + } + return mt.aggregateUp(mn.parent, acc, start, end, stop) +} + +// aggregateDown descends from the LCA (lowest common ancestor) of 
nodes within +// the range ending at the 'end'. On the way down, the unvisited left subtrees +// are guaranteed to lie within the range, so they're included into the +// aggregation using their saved fingerprint. +// If stop function is passed, we find the node on which it returns true +// for the fingerprint accumulated between start and that node +func (mt *monoidTree) aggregateDown(mn *monoidTreeNode, acc any, end Ordered, stop FingerprintPredicate) (fp any, node *monoidTreeNode, stopped bool) { + switch { + case mn == nil: + // fmt.Fprintf(os.Stderr, "QQQQQ: mn == nil\n") + return acc, nil, false + case stop.Match(acc): + // fmt.Fprintf(os.Stderr, "QQQQQ: stop on node\n") + return acc, mn.prev(), true + case end.Compare(mn.key) > 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: within the range\n") + // We're within the range but there also may be nodes + // within the range to the right. The left branch is + // fully within the range + f := mt.m.Op(acc, mt.safeFingerprint(mn.left)) + if stop.Match(f) { + // fmt.Fprintf(os.Stderr, "QQQQQ: left subtree covers it\n") + // The target node is somewhere in the left subtree + if mn.left == nil { + panic("BUG: aggregateDown: nil left child with non-identity fingerprint") + } + acc, node := mt.boundedAggregate(mn.left, acc, stop) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: returned acc %v node %#v\n", acc, node) + if node == nil { + panic("BUG: aggregateDown: bad subtree fingerprint on the left branch") + } + return acc, node, true + } + f1 := mt.m.Op(f, mt.m.Fingerprint(mn.key)) + if stop.Match(f1) { + // fmt.Fprintf(os.Stderr, "QQQQQ: stop at the node, prev %#v\n", node.prev()) + return f, mn, true + } else { + acc = f1 + } + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown on the right\n") + return mt.aggregateDown(mn.right, acc, end, stop) + case mn.left == nil || end.Compare(mn.left.max) > 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: node covers the range\n") + // Found the rightmost bounding node + f := mt.m.Op(acc, 
mt.safeFingerprint(mn.left)) + if stop.Match(f) { + // The target node is somewhere in the left subtree + if mn.left == nil { + panic("BUG: aggregateDown: nil left child with non-identity fingerprint") + } + acc, node := mt.boundedAggregate(mn.left, acc, stop) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate(2): returned acc %v node %#v\n", acc, node) + if node == nil { + panic("BUG: aggregateDown: bad subtree fingerprint on the left branch") + } + return acc, node, true + } + return f, mn, false + default: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown: going further down\n") + // We're too far to the right, outside the range + return mt.aggregateDown(mn.left, acc, end, stop) + } +} + +func (mt *monoidTree) boundedAggregate(mn *monoidTreeNode, acc any, stop FingerprintPredicate) (any, *monoidTreeNode) { + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: node %v, acc %v\n", mn.key, acc) + if mn == nil { + return acc, nil + } + + // If we don't need to stop, or if the stop point is somewhere after + // this subtree, we can just use the pre-calculated subtree fingerprint + if f := mt.m.Op(acc, mn.fingerprint); !stop.Match(f) { + return f, nil + } + + // This function is not supposed to be called with acc already matching + // the stop condition + if stop(acc) { + panic("BUG: boundedAggregate: initial fingerprint is matched before the first node") + } + + if mn.left != nil { + // See if we can skip recursion on the left branch + f := mt.m.Op(acc, mn.left.fingerprint) + if !stop(f) { + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and no-stop %v, f %v, left fingerprint %v\n", mn.key, f, mn.Left.Fingerprint) + acc = f + } else { + // The target node must be contained in the left subtree + var node *monoidTreeNode + acc, node = mt.boundedAggregate(mn.left, acc, stop) + if node == nil { + panic("BUG: boundedAggregate: bad subtree fingerprint on the left branch") + } + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and 
stop %v, new node %v, acc %v\n", mn.key, node.key, acc) + return acc, node + } + } + + f := mt.m.Op(acc, mt.m.Fingerprint(mn.key)) + + switch { + case stop(f): + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: stop at this node %v, fp %v\n", mn.key, acc) + return acc, mn + case mn.right != nil: + f1 := mt.m.Op(f, mn.right.fingerprint) + if !stop(f1) { + // The right branch is still below the target fingerprint + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and no-stop %v, acc %v\n", mn.key, acc) + acc = f1 + } else { + // The target node must be contained in the right subtree + var node *monoidTreeNode + acc, node := mt.boundedAggregate(mn.right, f, stop) + if node == nil { + panic("BUG: boundedAggregate: bad subtree fingerprint on the right branch") + } + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) + return acc, node + } + } + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", mn.key, acc) + // QQQQQ: ZXXXXXX: return acc, nil !!!! + return f, nil +} + +func (mt *monoidTree) Next(node MonoidTreeNode) MonoidTreeNode { + next := node.(*monoidTreeNode).next() + if next == nil { + return nil + } + return next +} + +func (mt *monoidTree) Dump() string { + if mt.root == nil { + return "" + } + var sb strings.Builder + mt.root.dump(&sb, 0) + return sb.String() +} + +// TBD: !!! values and Lookup (via findGTENode) !!! +// TBD: maybe: persistent rbtree -- note that MonoidTree will be immutable, +// too, in this case (Insert returns a new tree => no problem with thread safety +// or cached min/max) + +// Persistent tree: +// In 'derived' mode, along with the color, use 2 more bits: +// * whether this node is new +// * whether this node has new descendants +// Note that any combination possible (a new node MAY NOT have any new descendants) +// Derived trees are created with Derive(). 
The persistent mode may only be used +// for derived trees (need to see if this is worth it). +// With these bits, it is fairly easy to grab the newly added items w/o +// using the local sync algorithm diff --git a/hashsync/monoid_tree_store.go b/hashsync/monoid_tree_store.go new file mode 100644 index 0000000000..7ad948ccf3 --- /dev/null +++ b/hashsync/monoid_tree_store.go @@ -0,0 +1,105 @@ +package hashsync + +type monoidTreeIterator struct { + mt MonoidTree + node MonoidTreeNode +} + +var _ Iterator = monoidTreeIterator{} + +func (it monoidTreeIterator) Equal(other Iterator) bool { + o := other.(monoidTreeIterator) + if it.mt != o.mt { + panic("comparing iterators from different MonoidTreeStore") + } + return it.node == o.node +} + +func (it monoidTreeIterator) Key() Ordered { + return it.node.Key() +} + +func (it monoidTreeIterator) Next() Iterator { + next := it.node.Next() + if next == nil { + next = it.mt.Min() + } + if next == nil { + return nil + } + if next.(*monoidTreeNode) == nil { + panic("QQQQQ: wrapped nil in Next") + } + return monoidTreeIterator{ + mt: it.mt, + node: next, + } +} + +// TBD: Lookup method +type MonoidTreeStore struct { + mt MonoidTree +} + +var _ ItemStore = &MonoidTreeStore{} + +func NewMonoidTreeStore(m Monoid) ItemStore { + return &MonoidTreeStore{ + mt: NewMonoidTree(CombineMonoids(m, CountingMonoid{})), + } +} + +// Add implements ItemStore. +func (mts *MonoidTreeStore) Add(v Ordered) { + mts.mt.Add(v) +} + +func (mts *MonoidTreeStore) iter(node MonoidTreeNode) Iterator { + if node == nil { + return nil + } + if node.(*monoidTreeNode) == nil { + panic("QQQQQ: wrapped nil") + } + return monoidTreeIterator{ + mt: mts.mt, + node: node, + } +} + +// GetRangeInfo implements ItemStore. 
+func (mts *MonoidTreeStore) GetRangeInfo(preceding Iterator, x Ordered, y Ordered, count int) RangeInfo { + var stop FingerprintPredicate + var node MonoidTreeNode + if preceding != nil { + p := preceding.(monoidTreeIterator) + if p.mt != mts.mt { + panic("GetRangeInfo: preceding iterator from a wrong MonoidTreeStore") + } + node = p.node + } + if count >= 0 { + stop = func(fp any) bool { + return CombinedSecond[int](fp) > count + } + } + fp, startNode, endNode := mts.mt.RangeFingerprint(node, x, y, stop) + // fmt.Fprintf(os.Stderr, "QQQQQ: fp %v, startNode %#v, endNode %#v\n", fp, startNode, endNode) + cfp := fp.(CombinedFingerprint) + return RangeInfo{ + Fingerprint: cfp.First, + Count: cfp.Second.(int), + Start: mts.iter(startNode), + End: mts.iter(endNode), + } +} + +// Min implements ItemStore. +func (mts *MonoidTreeStore) Min() Iterator { + return mts.iter(mts.mt.Min()) +} + +// Max implements ItemStore. +func (mts *MonoidTreeStore) Max() Iterator { + return mts.iter(mts.mt.Max()) +} diff --git a/hashsync/monoid_tree_test.go b/hashsync/monoid_tree_test.go new file mode 100644 index 0000000000..f4bf2370bf --- /dev/null +++ b/hashsync/monoid_tree_test.go @@ -0,0 +1,430 @@ +package hashsync + +import ( + "cmp" + "fmt" + "math/rand" + "slices" + "testing" + + "github.com/stretchr/testify/require" +) + +type sampleID string + +var _ Ordered = sampleID("") + +func (s sampleID) Compare(other Ordered) int { + return cmp.Compare(s, other.(sampleID)) +} + +type sampleMonoid struct{} + +var _ Monoid = sampleMonoid{} + +func (m sampleMonoid) Identity() any { return "" } +func (m sampleMonoid) Op(a, b any) any { return a.(string) + b.(string) } +func (m sampleMonoid) Fingerprint(a any) any { return string(a.(sampleID)) } + +func sampleCountMonoid() Monoid { + return CombineMonoids(sampleMonoid{}, CountingMonoid{}) +} + +func makeStringConcatTree(chars string) MonoidTree { + ids := make([]sampleID, len(chars)) + for n, c := range chars { + ids[n] = sampleID(c) + } + 
return MonoidTreeFromSlice[sampleID](sampleCountMonoid(), ids) +} + +// dumbAdd inserts the node into the tree without trying to maintain the +// red-black properties +func dumbAdd(mt MonoidTree, v Ordered) { + mtree := mt.(*monoidTree) + mtree.root = mtree.insert(mtree.root, v, false) +} + +// makeDumbTree constructs a binary tree by adding the chars one-by-one without +// trying to maintain the red-black properties +func makeDumbTree(chars string) MonoidTree { + if len(chars) == 0 { + panic("empty set") + } + mt := NewMonoidTree(sampleCountMonoid()) + for _, c := range chars { + dumbAdd(mt, sampleID(c)) + } + return mt +} + +func makeRBTree(chars string) MonoidTree { + if len(chars) == 0 { + panic("empty set") + } + mt := NewMonoidTree(sampleCountMonoid()) + for _, c := range chars { + mt.Add(sampleID(c)) + } + return mt +} + +func gtePos(all string, item string) int { + n := slices.IndexFunc([]byte(all), func(v byte) bool { + return v >= item[0] + }) + if n >= 0 { + return n + } + return len(all) +} + +func naiveRange(all, x, y string, stopCount int) (fingerprint, startStr, endStr string) { + if len(all) == 0 { + return "", "", "" + } + allBytes := []byte(all) + slices.Sort(allBytes) + all = string(allBytes) + start := gtePos(all, x) + end := gtePos(all, y) + if x < y { + if stopCount >= 0 && end-start > stopCount { + end = start + stopCount + } + if end < len(all) { + endStr = all[end : end+1] + } else { + endStr = all[0:1] + } + startStr = "" + if start < len(all) { + startStr = all[start : start+1] + } else { + startStr = all[0:1] + } + return all[start:end], startStr, endStr + } else { + r := all[start:] + all[:end] + // fmt.Fprintf(os.Stderr, "QQQQQ: x %q start %d y %q end %d\n", x, start, y, end) + if len(r) == 0 { + // fmt.Fprintf(os.Stderr, "QQQQQ: x %q start %d y %q end %d -- ret start\n", x, start, y, end) + return "", all[0:1], all[0:1] + } + if stopCount >= 0 && len(r) > stopCount { + return r[:stopCount], r[0:1], r[stopCount : stopCount+1] + } + if 
end < len(all) { + endStr = all[end : end+1] + } else { + endStr = all[0:1] + } + startStr = "" + if len(r) != 0 { + startStr = r[0:1] + } + return r, startStr, endStr + } +} + +func TestEmptyTree(t *testing.T) { + tree := NewMonoidTree(sampleCountMonoid()) + rfp1, startNode, endNode := tree.RangeFingerprint(nil, sampleID("a"), sampleID("a"), nil) + require.Nil(t, startNode) + require.Nil(t, endNode) + rfp2, startNode, endNode := tree.RangeFingerprint(nil, sampleID("a"), sampleID("c"), nil) + require.Nil(t, startNode) + require.Nil(t, endNode) + rfp3, startNode, endNode := tree.RangeFingerprint(nil, sampleID("c"), sampleID("a"), nil) + require.Nil(t, startNode) + require.Nil(t, endNode) + for _, fp := range []any{ + tree.Fingerprint(), + rfp1, + rfp2, + rfp3, + } { + require.Equal(t, "", CombinedFirst[string](fp)) + require.Equal(t, 0, CombinedSecond[int](fp)) + } +} + +func testMonoidTreeRanges(t *testing.T, tree MonoidTree) { + all := "abcdefghijklmnopqr" + for _, tc := range []struct { + all string + x, y sampleID + gte string + fp string + stop int + startAt sampleID + endAt sampleID + }{ + // normal ranges: [x, y) (x -> y) + {x: "0", y: "9", stop: -1, startAt: "a", endAt: "a", fp: ""}, + {x: "x", y: "y", stop: -1, startAt: "a", endAt: "a", fp: ""}, + {x: "a", y: "b", stop: -1, startAt: "a", endAt: "b", fp: "a"}, + {x: "a", y: "d", stop: -1, startAt: "a", endAt: "d", fp: "abc"}, + {x: "f", y: "o", stop: -1, startAt: "f", endAt: "o", fp: "fghijklmn"}, + {x: "0", y: "y", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, + {x: "a", y: "r", stop: -1, startAt: "a", endAt: "r", fp: "abcdefghijklmnopq"}, + // full rollover range x -> end -> x, or [x, max) + [min, x) + {x: "a", y: "a", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, + {x: "l", y: "l", stop: -1, startAt: "l", endAt: "l", fp: "lmnopqrabcdefghijk"}, + // rollover ranges: x -> end -> y, or [x, max), [min, y) + {x: "l", y: "f", stop: -1, startAt: "l", endAt: "f", fp: 
"lmnopqrabcde"}, + {x: "l", y: "0", stop: -1, startAt: "l", endAt: "a", fp: "lmnopqr"}, + {x: "y", y: "f", stop: -1, startAt: "a", endAt: "f", fp: "abcde"}, + {x: "y", y: "x", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, + {x: "9", y: "0", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, + {x: "s", y: "a", stop: -1, startAt: "a", endAt: "a", fp: ""}, + // normal ranges + stop + {x: "a", y: "q", stop: 0, startAt: "a", endAt: "a", fp: ""}, + {x: "a", y: "q", stop: 3, startAt: "a", endAt: "d", fp: "abc"}, + {x: "a", y: "q", stop: 5, startAt: "a", endAt: "f", fp: "abcde"}, + {x: "a", y: "q", stop: 7, startAt: "a", endAt: "h", fp: "abcdefg"}, + {x: "a", y: "q", stop: 16, startAt: "a", endAt: "q", fp: "abcdefghijklmnop"}, + // rollover ranges + stop + {x: "l", y: "f", stop: 3, startAt: "l", endAt: "o", fp: "lmn"}, + {x: "l", y: "f", stop: 8, startAt: "l", endAt: "b", fp: "lmnopqra"}, + {x: "y", y: "x", stop: 5, startAt: "a", endAt: "f", fp: "abcde"}, + // full rollover range + stop + {x: "a", y: "a", stop: 3, startAt: "a", endAt: "d", fp: "abc"}, + {x: "a", y: "a", stop: 10, startAt: "a", endAt: "k", fp: "abcdefghij"}, + {x: "l", y: "l", stop: 3, startAt: "l", endAt: "o", fp: "lmn"}, + } { + testName := fmt.Sprintf("%s-%s", tc.x, tc.y) + if tc.stop >= 0 { + testName += fmt.Sprintf("-%d", tc.stop) + } + t.Run(testName, func(t *testing.T) { + rootFP := tree.Fingerprint() + require.Equal(t, all, CombinedFirst[string](rootFP)) + require.Equal(t, len(all), CombinedSecond[int](rootFP)) + stopCounts := []int{tc.stop} + if tc.stop < 0 { + // Stop point at the end of the sequence or beyond it + // should produce the same results as no stop point at all + stopCounts = append(stopCounts, len(all), len(all)*2) + } + for _, stopCount := range stopCounts { + // make sure naiveRangeWithStopCount works as epxected, even + // though it is only used for tests + fpStr, startStr, endStr := naiveRange(all, string(tc.x), string(tc.y), stopCount) + 
require.Equal(t, tc.fp, fpStr, "naive fingerprint") + require.Equal(t, string(tc.startAt), startStr, "naive fingerprint: startAt") + require.Equal(t, string(tc.endAt), endStr, "naive fingerprint: endAt") + + var stop FingerprintPredicate + if stopCount >= 0 { + // stopCount is not used after this iteration + // so it's ok to have it captured in the closure + stop = func(fp any) bool { + count := CombinedSecond[int](fp) + return count > stopCount + } + } + fp, startNode, endNode := tree.RangeFingerprint(nil, tc.x, tc.y, stop) + require.Equal(t, tc.fp, CombinedFirst[string](fp), "fingerprint") + require.Equal(t, len(tc.fp), CombinedSecond[int](fp), "count") + require.NotNil(t, startNode, "start node") + require.NotNil(t, endNode, "end node") + require.Equal(t, tc.startAt, startNode.Key(), "start node key") + require.Equal(t, tc.endAt, endNode.Key(), "end node key") + } + }) + } +} + +func TestMonoidTreeRanges(t *testing.T) { + t.Run("pre-balanced tree", func(t *testing.T) { + testMonoidTreeRanges(t, makeStringConcatTree("abcdefghijklmnopqr")) + }) + t.Run("sequential add", func(t *testing.T) { + testMonoidTreeRanges(t, makeDumbTree("abcdefghijklmnopqr")) + }) + t.Run("shuffled add", func(t *testing.T) { + testMonoidTreeRanges(t, makeDumbTree("lodrnifeqacmbhkgjp")) + }) + t.Run("red-black add", func(t *testing.T) { + testMonoidTreeRanges(t, makeRBTree("lodrnifeqacmbhkgjp")) + }) +} + +func TestAscendingRanges(t *testing.T) { + all := "abcdefghijklmnopqr" + tree := makeRBTree(all) + for _, tc := range []struct { + name string + ranges []string + fingerprints []string + }{ + { + name: "normal ranges", + ranges: []string{"ac", "cj", "lq", "qr"}, + fingerprints: []string{"ab", "cdefghi", "lmnop", "q"}, + }, + { + name: "normal and inverted ranges", + ranges: []string{"xc", "cj", "p0"}, + fingerprints: []string{"ab", "cdefghi", "pqr"}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var fps []string + var node MonoidTreeNode + for n, rng := range tc.ranges { + x := 
sampleID(rng[0]) + y := sampleID(rng[1]) + if n > 0 { + require.NotNil(t, node, "nil starting node for range %s-%s", x, y) + } + fpStr, _, _ := naiveRange(all, string(x), string(y), -1) + var fp any + fp, _, node = tree.RangeFingerprint(node, x, y, nil) + actualFP := CombinedFirst[string](fp) + require.Equal(t, len(actualFP), CombinedSecond[int](fp), "count") + require.Equal(t, fpStr, actualFP) + fps = append(fps, actualFP) + } + require.Equal(t, tc.fingerprints, fps, "fingerprints") + }) + } +} + +func verifyBinaryTree(t *testing.T, mn *monoidTreeNode) { + if mn.parent != nil && mn != mn.parent.left && mn != mn.parent.right { + require.Fail(t, "node is an 'unknown' child") + } + + if mn.left != nil { + require.Equal(t, mn, mn.left.parent, "bad parent node on the left branch") + require.Negative(t, mn.left.key.Compare(mn.key)) + leftMaxNode := mn.left.maxNode() + require.Negative(t, leftMaxNode.key.Compare(mn.key)) + verifyBinaryTree(t, mn.left) + } + + if mn.right != nil { + require.Equal(t, mn, mn.right.parent, "bad parent node on the right branch") + require.Positive(t, mn.right.key.Compare(mn.key)) + rightMinNode := mn.right.minNode() + require.Positive(t, rightMinNode.key.Compare(mn.key)) + verifyBinaryTree(t, mn.right) + } +} + +func verifyRedBlack(t *testing.T, mn *monoidTreeNode, blackDepth int) int { + if mn == nil { + return blackDepth + 1 + } + if mn.color == red { + require.NotNil(t, mn.parent, "root node must be black") + require.Equal(t, black, mn.parent.color, "parent of a red node is red") + if mn.left != nil { + require.Equal(t, black, mn.left.color, "left child of a red node is red") + } + if mn.right != nil { + require.Equal(t, black, mn.right.color, "right child of a red node is red") + } + } else { + blackDepth++ + } + bdLeft := verifyRedBlack(t, mn.left, blackDepth) + bdRight := verifyRedBlack(t, mn.right, blackDepth) + require.Equal(t, bdLeft, bdRight, "subtree black depth for node %s", mn.key) + return bdLeft +} + +func 
TestRedBlackTreeInsert(t *testing.T) { + for i := 0; i < 1000; i++ { + tree := NewMonoidTree(sampleCountMonoid()) + items := []byte("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") + count := rand.Intn(len(items)) + 1 + items = items[:count] + shuffled := append([]byte(nil), items...) + rand.Shuffle(len(shuffled), func(i, j int) { + shuffled[i], shuffled[j] = shuffled[j], shuffled[i] + }) + + // items := []byte("0123456789ABCDEFG") + // shuffled := []byte("0678DF1CG5A9324BE") + + for i := 0; i < len(shuffled); i++ { + tree.Add(sampleID(shuffled[i])) // XXXX: Insert + } + var actualItems []byte + n := 0 + // t.Logf("items: %q", string(items)) + // t.Logf("shuffled: %q", string(shuffled)) + // t.Logf("QQQQQ: tree:\n%s", tree.Dump()) + root := tree.(*monoidTree).root + verifyBinaryTree(t, root) + verifyRedBlack(t, root, 0) + for node := tree.Min(); node != nil; node = node.Next() { + // avoid endless loop due to bugs in the tree impl + require.Less(t, n, len(items)*2, "got much more items than needed: %q -- %q", actualItems, shuffled) + n++ + actualItems = append(actualItems, node.Key().(sampleID)[0]) + } + require.Equal(t, items, actualItems) + + fp, startNode, endNode := tree.RangeFingerprint(nil, sampleID(items[0]), sampleID(items[0]), nil) + fpStr := CombinedFirst[string](fp) + require.Equal(t, string(items), fpStr, "fingerprint %q", shuffled) + require.Equal(t, len(fpStr), CombinedSecond[int](fp), "count %q") + require.Equal(t, sampleID(items[0]), startNode.Key(), "startNode") + require.Equal(t, sampleID(items[0]), endNode.Key(), "endNode") + } +} + +type makeTestTreeFunc func(chars string) MonoidTree + +func testRandomOrderAndRanges(t *testing.T, mktree makeTestTreeFunc) { + all := "abcdefghijklmnopqr" + for i := 0; i < 1000; i++ { + shuffled := []byte(all) + rand.Shuffle(len(shuffled), func(i, j int) { + shuffled[i], shuffled[j] = shuffled[j], shuffled[i] + }) + tree := makeDumbTree(string(shuffled)) + x := 
sampleID(shuffled[rand.Intn(len(shuffled))]) + y := sampleID(shuffled[rand.Intn(len(shuffled))]) + stopCount := rand.Intn(len(shuffled)+2) - 1 + var stop FingerprintPredicate + if stopCount >= 0 { + stop = func(fp any) bool { + return CombinedSecond[int](fp) > stopCount + } + } + + expFP, expStart, expEnd := naiveRange(all, string(x), string(y), stopCount) + fp, startNode, endNode := tree.RangeFingerprint(nil, x, y, stop) + + fpStr := CombinedFirst[string](fp) + curCase := fmt.Sprintf("items %q x %q y %q stopCount %d", shuffled, x, y, stopCount) + require.Equal(t, expFP, fpStr, "%s: fingerprint", curCase) + require.Equal(t, len(fpStr), CombinedSecond[int](fp), "%s: count", curCase) + + startStr := "" + if startNode != nil { + startStr = string(startNode.Key().(sampleID)) + } + require.Equal(t, expStart, startStr, "%s: next", curCase) + + endStr := "" + if endNode != nil { + endStr = string(endNode.Key().(sampleID)) + } + require.Equal(t, expEnd, endStr, "%s: next", curCase) + } +} + +func TestRandomOrderAndRanges(t *testing.T) { + t.Run("randomized dumb insert", func(t *testing.T) { + testRandomOrderAndRanges(t, makeDumbTree) + }) + t.Run("red-black tree", func(t *testing.T) { + testRandomOrderAndRanges(t, makeRBTree) + }) +} diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go new file mode 100644 index 0000000000..9a65451a53 --- /dev/null +++ b/hashsync/rangesync.go @@ -0,0 +1,208 @@ +package hashsync + +import ( + "fmt" + "reflect" +) + +const ( + defaultMaxSendRange = 16 +) + +type RangeMessage struct { + X, Y Ordered + Fingerprint any + Count int + Items []Ordered +} + +func (m RangeMessage) String() string { + itemsStr := "" + if len(m.Items) != 0 { + itemsStr = fmt.Sprintf(" +%d items", len(m.Items)) + } + return fmt.Sprintf("", + m.X, m.Y, m.Count, m.Fingerprint, itemsStr) +} + +type Option func(r *RangeSetReconciler) + +func WithMaxSendRange(n int) Option { + return func(r *RangeSetReconciler) { + r.maxSendRange = n + } +} + +// Iterator points to in 
item in ItemStore +type Iterator interface { + // Equal returns true if this iterator is equal to another Iterator + Equal(other Iterator) bool + // Key returns the key corresponding to iterator + Key() Ordered + // Next returns an iterator pointing to the next key or nil + // if this key is the last one in the store + Next() Iterator +} + +type RangeInfo struct { + Fingerprint any + Count int + Start, End Iterator +} + +type ItemStore interface { + Add(v Ordered) + // GetRangeInfo returns RangeInfo for the item range in the tree. + // If count >= 0, at most count items are returned, and RangeInfo + // is returned for the corresponding subrange of the requested range + GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo + // Min returns the iterator pointing at the minimum element + // in the store. If the store is empty, it returns nil + Min() Iterator + // Max returns the iterator pointing at the maximum element + // in the store. If the store is empty, it returns nil + Max() Iterator +} + +type RangeSetReconciler struct { + is ItemStore + maxSendRange int +} + +func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { + rsr := &RangeSetReconciler{ + is: is, + maxSendRange: defaultMaxSendRange, + } + for _, opt := range opts { + opt(rsr) + } + if rsr.maxSendRange == 0 { + panic("zero maxSendRange") + } + return rsr +} + +func (rsr *RangeSetReconciler) addItems(in []RangeMessage) { + for _, msg := range in { + for _, item := range msg.Items { + rsr.is.Add(item) + } + } +} + +func (rsr *RangeSetReconciler) getItems(start, end Iterator) []Ordered { + var r []Ordered + it := start + for { + r = append(r, it.Key()) + it = it.Next() + if it == end { + return r + } + } +} + +func (rsr *RangeSetReconciler) processSubrange(preceding Iterator, x, y Ordered) (RangeMessage, Iterator) { + if preceding != nil && preceding.Key().Compare(x) > 0 { + preceding = nil + } + info := rsr.is.GetRangeInfo(preceding, x, y, -1) + msg := RangeMessage{ + 
X: x, + Y: y, + Fingerprint: info.Fingerprint, + Count: info.Count, + } + // If the range is small enough, we send its contents + if info.Count != 0 && info.Count <= rsr.maxSendRange { + msg.Items = rsr.getItems(info.Start, info.End) + } + return msg, info.End +} + +func (rsr *RangeSetReconciler) processFingerprint(preceding Iterator, msg RangeMessage) ([]RangeMessage, Iterator) { + if msg.X == nil && msg.Y == nil { + it := rsr.is.Min() + if it == nil { + return nil, nil + } + msg.X = it.Key() + msg.Y = msg.X + } else if msg.X == nil || msg.Y == nil { + // TBD: don't pass just one nil when decoding!!! + panic("invalid range") + } + info := rsr.is.GetRangeInfo(preceding, msg.X, msg.Y, -1) + // fmt.Fprintf(os.Stderr, "msg %s fp %v start %#v end %#v count %d\n", msg, info.Fingerprint, info.Start, info.End, info.Count) + switch { + // FIXME: use Fingerprint interface for fingerprints + // with Equal() method + case reflect.DeepEqual(info.Fingerprint, msg.Fingerprint): + // fmt.Fprintf(os.Stderr, "range synced: %s\n", msg) + // the range is synced + return nil, info.End + case info.Count <= rsr.maxSendRange || msg.Count == 0: + // The other side is missing some items, and either + // range is small enough or empty on the other side + resp := RangeMessage{ + X: msg.X, + Y: msg.Y, + Fingerprint: info.Fingerprint, + Count: info.Count, + } + if info.Count != 0 { + resp.Items = rsr.getItems(info.Start, info.End) + } + // fmt.Fprintf(os.Stderr, "small/empty incoming range: %s -> %s\n", msg, resp) + return []RangeMessage{resp}, info.End + default: + // Need to split the range. 
+ // Note that there's no special handling for rollover ranges with x >= y + // These need to be handled by ItemStore.GetRangeInfo() + count := info.Count / 2 + part := rsr.is.GetRangeInfo(preceding, msg.X, msg.Y, count) + middle := part.End.Key() + if middle == nil { + panic("BUG: can't split range with count > 1") + } + msg1, next := rsr.processSubrange(info.Start, msg.X, middle) + msg2, _ := rsr.processSubrange(next, middle, msg.Y) + // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s ->\n %s\n %s\n", + // msg.X, middle, msg.Y, msg, msg1, msg2) + return []RangeMessage{msg1, msg2}, info.End + } +} + +func (rsr *RangeSetReconciler) Initiate() RangeMessage { + it := rsr.is.Min() + if it == nil { + // Create a message with count 0 + return RangeMessage{} + } + min := it.Key() + info := rsr.is.GetRangeInfo(nil, min, min, -1) + return RangeMessage{ + X: min, + Y: min, + Fingerprint: info.Fingerprint, + Count: info.Count, + } +} + +func (rsr *RangeSetReconciler) Process(in []RangeMessage) []RangeMessage { + rsr.addItems(in) + var out []RangeMessage + for _, msg := range in { + // TODO: need to sort ranges, but also need to be careful + msgs, _ := rsr.processFingerprint(nil, msg) + out = append(out, msgs...) 
+ } + return out +} + +// TBD: limit the number of rounds +// TBD: join adjacent ranges in the input +// TBD: process ascending ranges properly +// TBD: join adjacent ranges in the output +// TBD: bounded reconcile diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go new file mode 100644 index 0000000000..92afe9b452 --- /dev/null +++ b/hashsync/rangesync_test.go @@ -0,0 +1,459 @@ +package hashsync + +import ( + "math/rand" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" +) + +type dumbStoreIterator struct { + ds *dumbStore + n int +} + +var _ Iterator = dumbStoreIterator{} + +func (it dumbStoreIterator) Equal(other Iterator) bool { + o := other.(dumbStoreIterator) + if it.ds != o.ds { + panic("comparing iterators from different dumbStores") + } + return it.n == o.n +} + +func (it dumbStoreIterator) Key() Ordered { + return it.ds.items[it.n] +} + +func (it dumbStoreIterator) Next() Iterator { + return dumbStoreIterator{ + ds: it.ds, + n: (it.n + 1) % len(it.ds.items), + } +} + +type dumbStore struct { + items []sampleID +} + +var _ ItemStore = &dumbStore{} + +func (ds *dumbStore) Add(v Ordered) { + //slices.Insert[S ~[]E, E any](s S, i int, v ...E) + id := v.(sampleID) + if len(ds.items) == 0 { + ds.items = []sampleID{id} + return + } + p := slices.IndexFunc(ds.items, func(other sampleID) bool { + return other >= id + }) + switch { + case p < 0: + ds.items = append(ds.items, id) + case id == ds.items[p]: + // already present + default: + ds.items = slices.Insert(ds.items, p, id) + } +} + +func (ds *dumbStore) iter(n int) Iterator { + if n == -1 || n == len(ds.items) { + return nil + } + return dumbStoreIterator{ds: ds, n: n} +} + +func (ds *dumbStore) last() sampleID { + if len(ds.items) == 0 { + panic("can't get the last element: zero items") + } + return ds.items[len(ds.items)-1] +} + +func (ds *dumbStore) iterFor(s sampleID) Iterator { + n := slices.Index(ds.items, s) + if n == -1 { + panic("item not 
found: " + s) + } + return ds.iter(n) +} + +func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { + all := storeItemStr(ds) + vx := x.(sampleID) + vy := y.(sampleID) + if preceding != nil && preceding.Key().Compare(x) > 0 { + panic("preceding info after x") + } + fp, startStr, endStr := naiveRange(all, string(vx), string(vy), count) + r := RangeInfo{ + Fingerprint: fp, + Count: len(fp), + } + if all != "" { + if startStr == "" || endStr == "" { + panic("empty startStr/endStr from naiveRange") + } + r.Start = ds.iterFor(sampleID(startStr)) + r.End = ds.iterFor(sampleID(endStr)) + } + return r +} + +func (ds *dumbStore) Min() Iterator { + if len(ds.items) == 0 { + return nil + } + return dumbStoreIterator{ + ds: ds, + n: 0, + } +} + +func (ds *dumbStore) Max() Iterator { + if len(ds.items) == 0 { + return nil + } + return dumbStoreIterator{ + ds: ds, + n: len(ds.items) - 1, + } +} + +type verifiedStoreIterator struct { + t *testing.T + knownGood Iterator + it Iterator +} + +var _ Iterator = &verifiedStoreIterator{} + +func (it verifiedStoreIterator) Equal(other Iterator) bool { + o := other.(verifiedStoreIterator) + eq1 := it.knownGood.Equal(o.knownGood) + eq2 := it.it.Equal(o.it) + require.Equal(it.t, eq1, eq2, "iterators equal") + require.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys of equal iterators") + return eq2 +} + +func (it verifiedStoreIterator) Key() Ordered { + k1 := it.knownGood.Key() + k2 := it.it.Key() + require.Equal(it.t, k1, k2, "keys") + return k2 +} + +func (it verifiedStoreIterator) Next() Iterator { + next1 := it.knownGood.Next() + next2 := it.it.Next() + require.Equal(it.t, next1.Key(), next2.Key(), "keys for Next()") + return verifiedStoreIterator{ + t: it.t, + knownGood: next1, + it: next2, + } +} + +type verifiedStore struct { + t *testing.T + knownGood ItemStore + store ItemStore +} + +var _ ItemStore = &verifiedStore{} + +func (vs *verifiedStore) Add(v Ordered) { + vs.knownGood.Add(v) + 
vs.store.Add(v) +} + +func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { + var ri1, ri2 RangeInfo + if preceding != nil { + p := preceding.(verifiedStoreIterator) + ri1 = vs.knownGood.GetRangeInfo(p.knownGood, x, y, count) + ri2 = vs.store.GetRangeInfo(p.it, x, y, count) + } else { + ri1 = vs.knownGood.GetRangeInfo(nil, x, y, count) + ri2 = vs.store.GetRangeInfo(nil, x, y, count) + } + require.Equal(vs.t, ri1.Fingerprint, ri2.Fingerprint, "range info fingerprint") + require.Equal(vs.t, ri1.Count, ri2.Count, "range info count") + ri := RangeInfo{ + Fingerprint: ri2.Fingerprint, + Count: ri2.Count, + } + if ri1.Start == nil { + require.Nil(vs.t, ri2.Start, "range info start") + require.Nil(vs.t, ri1.End, "range info end (known good)") + require.Nil(vs.t, ri2.End, "range info end") + } else { + require.NotNil(vs.t, ri2.Start, "range info start") + require.Equal(vs.t, ri1.Start.Key(), ri2.Start.Key(), "range info start key") + require.NotNil(vs.t, ri1.End, "range info end (known good)") + require.NotNil(vs.t, ri2.End, "range info end") + ri.Start = verifiedStoreIterator{ + t: vs.t, + knownGood: ri1.Start, + it: ri2.Start, + } + } + if ri1.End == nil { + require.Nil(vs.t, ri2.End, "range info end") + } else { + require.NotNil(vs.t, ri2.End, "range info end") + require.Equal(vs.t, ri1.End.Key(), ri2.End.Key(), "range info end key") + ri.End = verifiedStoreIterator{ + t: vs.t, + knownGood: ri1.End, + it: ri2.End, + } + } + return ri +} + +func (vs *verifiedStore) Min() Iterator { + m1 := vs.knownGood.Min() + m2 := vs.knownGood.Min() + if m1 == nil { + require.Nil(vs.t, m2, "Min") + return nil + } else { + require.NotNil(vs.t, m2, "Min") + require.Equal(vs.t, m1.Key(), m2.Key(), "Min key") + } + return verifiedStoreIterator{ + t: vs.t, + knownGood: m1, + it: m2, + } +} + +func (vs *verifiedStore) Max() Iterator { + m1 := vs.knownGood.Max() + m2 := vs.knownGood.Max() + if m1 == nil { + require.Nil(vs.t, m2, "Max") + return nil + } 
else { + require.NotNil(vs.t, m2, "Max") + require.Equal(vs.t, m1.Key(), m2.Key(), "Max key") + } + return verifiedStoreIterator{ + t: vs.t, + knownGood: m1, + it: m2, + } +} + +type storeFactory func(t *testing.T) ItemStore + +func makeDumbStore(t *testing.T) ItemStore { + return &dumbStore{} +} + +func makeMonoidTreeStore(t *testing.T) ItemStore { + return NewMonoidTreeStore(sampleMonoid{}) +} + +func makeVerifiedMonoidTreeStore(t *testing.T) ItemStore { + return &verifiedStore{ + t: t, + knownGood: makeDumbStore(t), + store: makeMonoidTreeStore(t), + } +} + +func makeStore(t *testing.T, f storeFactory, items string) ItemStore { + s := f(t) + for _, c := range items { + s.Add(sampleID(c)) + } + return s +} + +func storeItemStr(is ItemStore) string { + it := is.Min() + if it == nil { + return "" + } + endAt := is.Max() + r := "" + for { + r += string(it.Key().(sampleID)) + if it == endAt { + return r + } + it = it.Next() + } +} + +var testStores = []struct { + name string + factory storeFactory +}{ + { + name: "dumb store", + factory: makeDumbStore, + }, + { + name: "monoid tree store", + factory: makeMonoidTreeStore, + }, + { + name: "verified monoid tree store", + factory: makeVerifiedMonoidTreeStore, + }, +} + +func forTestStores(t *testing.T, testFunc func(t *testing.T, factory storeFactory)) { + for _, s := range testStores { + t.Run(s.name, func(t *testing.T) { + testFunc(t, s.factory) + }) + } +} + +func dumpRangeMessages(t *testing.T, msgs []RangeMessage, fmt string, args ...any) { + t.Logf(fmt, args...) 
+ for _, m := range msgs { + t.Logf(" %s", m) + } +} + +func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds int) { + msgs := []RangeMessage{syncA.Initiate()} + var i int + for i = 0; len(msgs) != 0; i++ { + if i == maxRounds { + require.FailNow(t, "too many rounds", "didn't reconcile in %d rounds", i) + } + // dumpRangeMessages(t, msgs, "A %q -> B %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) + msgs = syncB.Process(msgs) + if msgs != nil { + // dumpRangeMessages(t, msgs, "B %q --> A %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) + msgs = syncA.Process(msgs) + } + } + // even with empty sets, there must be an exchange of messages + require.Greater(t, i, 0, "wrong reconc in zero rounds") + return i +} + +func testRangeSync(t *testing.T, storeFactory storeFactory) { + for _, tc := range []struct { + name string + a, b string + final string + maxRounds [4]int + }{ + { + name: "empty sets", + a: "", + b: "", + final: "", + maxRounds: [4]int{1, 1, 1, 1}, + }, + { + name: "empty to non-empty", + a: "", + b: "abcd", + final: "abcd", + maxRounds: [4]int{1, 1, 1, 1}, + }, + { + name: "non-empty to empty", + a: "abcd", + b: "", + final: "abcd", + maxRounds: [4]int{2, 2, 2, 2}, + }, + { + name: "non-intersecting sets", + a: "ab", + b: "cd", + final: "abcd", + maxRounds: [4]int{3, 2, 2, 2}, + }, + { + name: "intersecting sets", + a: "acdefghijklmn", + b: "bcdopqr", + final: "abcdefghijklmnopqr", + maxRounds: [4]int{4, 4, 4, 3}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + for n, maxSendRange := range []int{1, 2, 3, 4} { + t.Logf("maxSendRange: %d", maxSendRange) + storeA := makeStore(t, storeFactory, tc.a) + syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(maxSendRange)) + storeB := makeStore(t, storeFactory, tc.b) + syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(maxSendRange)) + + nRounds := runSync(t, syncA, syncB, tc.maxRounds[n]) + t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) + 
+ require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) + require.Equal(t, tc.final, storeItemStr(storeA)) + } + }) + } +} + +func TestRangeSync(t *testing.T) { + forTestStores(t, testRangeSync) +} + +func testRandomSync(t *testing.T, storeFactory storeFactory) { + for i := 0; i < 1000; i++ { + var chars []byte + for c := byte(33); c < 127; c++ { + chars = append(chars, c) + } + + bytesA := append([]byte(nil), chars...) + rand.Shuffle(len(bytesA), func(i, j int) { + bytesA[i], bytesA[j] = bytesA[j], bytesA[i] + }) + bytesA = bytesA[:rand.Intn(len(bytesA))] + storeA := makeStore(t, storeFactory, string(bytesA)) + + bytesB := append([]byte(nil), chars...) + rand.Shuffle(len(bytesB), func(i, j int) { + bytesB[i], bytesB[j] = bytesB[j], bytesB[i] + }) + bytesB = bytesB[:rand.Intn(len(bytesB))] + storeB := makeStore(t, storeFactory, string(bytesB)) + + keySet := make(map[byte]struct{}) + for _, c := range append(bytesA, bytesB...) { + keySet[byte(c)] = struct{}{} + } + + expectedSet := maps.Keys(keySet) + slices.Sort(expectedSet) + + maxSendRange := rand.Intn(16) + 1 + syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(maxSendRange)) + syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(maxSendRange)) + + runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! 
+ require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) + require.Equal(t, string(expectedSet), storeItemStr(storeA), + "expected set for %q<->%q", bytesA, bytesB) + } +} + +func TestRandomSync(t *testing.T) { + forTestStores(t, testRandomSync) +} + +// TBD: random test for MonoidTreeStore +// TBD: use logger for verbose logging (messages) From 5844d835820d22dbfd544faddc1d008306bf47ac Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 5 Jan 2024 01:11:05 +0400 Subject: [PATCH 02/76] hashsync: refactor rangesync to make it more stream-friendly --- hashsync/rangesync.go | 171 +++++++++++++++++++------------------ hashsync/rangesync_test.go | 126 ++++++++++++++++++++++++--- 2 files changed, 203 insertions(+), 94 deletions(-) diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 9a65451a53..274cbee669 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -1,7 +1,6 @@ package hashsync import ( - "fmt" "reflect" ) @@ -9,20 +8,25 @@ const ( defaultMaxSendRange = 16 ) -type RangeMessage struct { - X, Y Ordered - Fingerprint any - Count int - Items []Ordered +type SyncMessage interface { + X() Ordered + Y() Ordered + Fingerprint() any + Count() int + HaveItems() bool } -func (m RangeMessage) String() string { - itemsStr := "" - if len(m.Items) != 0 { - itemsStr = fmt.Sprintf(" +%d items", len(m.Items)) - } - return fmt.Sprintf("", - m.X, m.Y, m.Count, m.Fingerprint, itemsStr) +type Conduit interface { + // NextMessage returns the next SyncMessage, or nil if there + // are no more SyncMessages. 
+ NextMessage() (SyncMessage, error) + // NextItem returns the next item in the set or nil if there + // are no more items + NextItem() (Ordered, error) + // SendFingerprint sends range fingerprint to the peer + SendFingerprint(x, y Ordered, fingerprint any, count int) + // SendItems sends range fingerprint to the peer along with the items + SendItems(x, y Ordered, fingerprint any, count int, start, end Iterator) } type Option func(r *RangeSetReconciler) @@ -83,126 +87,129 @@ func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { return rsr } -func (rsr *RangeSetReconciler) addItems(in []RangeMessage) { - for _, msg := range in { - for _, item := range msg.Items { - rsr.is.Add(item) - } - } -} - -func (rsr *RangeSetReconciler) getItems(start, end Iterator) []Ordered { - var r []Ordered - it := start +func (rsr *RangeSetReconciler) addItems(c Conduit) error { for { - r = append(r, it.Key()) - it = it.Next() - if it == end { - return r + item, err := c.NextItem() + if err != nil { + return err + } + if item == nil { + return nil } + rsr.is.Add(item) } } -func (rsr *RangeSetReconciler) processSubrange(preceding Iterator, x, y Ordered) (RangeMessage, Iterator) { +func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding, start, end Iterator, x, y Ordered) Iterator { if preceding != nil && preceding.Key().Compare(x) > 0 { preceding = nil } info := rsr.is.GetRangeInfo(preceding, x, y, -1) - msg := RangeMessage{ - X: x, - Y: y, - Fingerprint: info.Fingerprint, - Count: info.Count, - } // If the range is small enough, we send its contents if info.Count != 0 && info.Count <= rsr.maxSendRange { - msg.Items = rsr.getItems(info.Start, info.End) + c.SendItems(x, y, info.Fingerprint, info.Count, start, end) + } else { + c.SendFingerprint(x, y, info.Fingerprint, info.Count) } - return msg, info.End + return info.End } -func (rsr *RangeSetReconciler) processFingerprint(preceding Iterator, msg RangeMessage) ([]RangeMessage, Iterator) { - if msg.X 
== nil && msg.Y == nil { +func (rsr *RangeSetReconciler) processFingerprint(c Conduit, preceding Iterator, msg SyncMessage) Iterator { + x := msg.X() + y := msg.Y() + if x == nil && y == nil { + // The peer has no items at all so didn't + // even send X & Y it := rsr.is.Min() if it == nil { - return nil, nil + // We don't have any items at all, too + return nil } - msg.X = it.Key() - msg.Y = msg.X - } else if msg.X == nil || msg.Y == nil { - // TBD: don't pass just one nil when decoding!!! + x = it.Key() + y = x + } else if x == nil || y == nil { + // TBD: never pass just one nil when decoding!!! panic("invalid range") } - info := rsr.is.GetRangeInfo(preceding, msg.X, msg.Y, -1) + info := rsr.is.GetRangeInfo(preceding, x, y, -1) // fmt.Fprintf(os.Stderr, "msg %s fp %v start %#v end %#v count %d\n", msg, info.Fingerprint, info.Start, info.End, info.Count) switch { // FIXME: use Fingerprint interface for fingerprints // with Equal() method - case reflect.DeepEqual(info.Fingerprint, msg.Fingerprint): + case reflect.DeepEqual(info.Fingerprint, msg.Fingerprint()): // fmt.Fprintf(os.Stderr, "range synced: %s\n", msg) // the range is synced - return nil, info.End - case info.Count <= rsr.maxSendRange || msg.Count == 0: + return info.End + case info.Count <= rsr.maxSendRange || msg.Count() == 0: // The other side is missing some items, and either // range is small enough or empty on the other side - resp := RangeMessage{ - X: msg.X, - Y: msg.Y, - Fingerprint: info.Fingerprint, - Count: info.Count, - } if info.Count != 0 { - resp.Items = rsr.getItems(info.Start, info.End) + // fmt.Fprintf(os.Stderr, "small/empty incoming range: %s -> SendItems\n", msg) + c.SendItems(x, y, info.Fingerprint, info.Count, info.Start, info.End) + } else { + // fmt.Fprintf(os.Stderr, "small/empty incoming range: %s -> zero count msg\n", msg) + c.SendFingerprint(x, y, info.Fingerprint, info.Count) } - // fmt.Fprintf(os.Stderr, "small/empty incoming range: %s -> %s\n", msg, resp) - return 
[]RangeMessage{resp}, info.End + return info.End default: // Need to split the range. // Note that there's no special handling for rollover ranges with x >= y // These need to be handled by ItemStore.GetRangeInfo() count := info.Count / 2 - part := rsr.is.GetRangeInfo(preceding, msg.X, msg.Y, count) - middle := part.End.Key() - if middle == nil { + part := rsr.is.GetRangeInfo(preceding, x, y, count) + if part.End == nil { panic("BUG: can't split range with count > 1") } - msg1, next := rsr.processSubrange(info.Start, msg.X, middle) - msg2, _ := rsr.processSubrange(next, middle, msg.Y) - // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s ->\n %s\n %s\n", - // msg.X, middle, msg.Y, msg, msg1, msg2) - return []RangeMessage{msg1, msg2}, info.End + middle := part.End.Key() + next := rsr.processSubrange(c, info.Start, part.Start, part.End, x, middle) + rsr.processSubrange(c, next, part.End, info.End, middle, y) + // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", + // msg.X(), middle, msg.Y(), msg) + return info.End } } -func (rsr *RangeSetReconciler) Initiate() RangeMessage { +func (rsr *RangeSetReconciler) Initiate(c Conduit) { it := rsr.is.Min() if it == nil { // Create a message with count 0 - return RangeMessage{} + c.SendFingerprint(nil, nil, nil, 0) + return } min := it.Key() info := rsr.is.GetRangeInfo(nil, min, min, -1) - return RangeMessage{ - X: min, - Y: min, - Fingerprint: info.Fingerprint, - Count: info.Count, + if info.Count != 0 && info.Count < rsr.maxSendRange { + c.SendItems(min, min, info.Fingerprint, info.Count, info.Start, info.End) + } else { + c.SendFingerprint(min, min, info.Fingerprint, info.Count) } } -func (rsr *RangeSetReconciler) Process(in []RangeMessage) []RangeMessage { - rsr.addItems(in) - var out []RangeMessage - for _, msg := range in { +func (rsr *RangeSetReconciler) Process(c Conduit) error { + var msgs []SyncMessage + for { + msg, err := c.NextMessage() + if err != nil { + return err + } + if 
msg == nil { + break + } + msgs = append(msgs, msg) + } + + if err := rsr.addItems(c); err != nil { + return err + } + + for _, msg := range msgs { // TODO: need to sort ranges, but also need to be careful - msgs, _ := rsr.processFingerprint(nil, msg) - out = append(out, msgs...) + rsr.processFingerprint(c, nil, msg) } - return out + + return nil } -// TBD: limit the number of rounds -// TBD: join adjacent ranges in the input +// TBD: limit the number of rounds (outside RangeSetReconciler) // TBD: process ascending ranges properly -// TBD: join adjacent ranges in the output // TBD: bounded reconcile diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 92afe9b452..da9ffd3b76 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -1,6 +1,7 @@ package hashsync import ( + "fmt" "math/rand" "slices" "testing" @@ -9,6 +10,104 @@ import ( "golang.org/x/exp/maps" ) +type rangeMessage struct { + x, y Ordered + fp any + count int + haveItems bool +} + +func (m rangeMessage) X() Ordered { return m.x } +func (m rangeMessage) Y() Ordered { return m.y } +func (m rangeMessage) Fingerprint() any { return m.fp } +func (m rangeMessage) Count() int { return m.count } +func (m rangeMessage) HaveItems() bool { return m.haveItems } + +var _ SyncMessage = rangeMessage{} + +func (m rangeMessage) String() string { + itemsStr := "" + if m.haveItems { + itemsStr = fmt.Sprintf(" +items") + } + return fmt.Sprintf("", + m.x, m.y, m.count, m.fp, itemsStr) +} + +type fakeConduit struct { + msgs []rangeMessage + items []Ordered + resp *fakeConduit +} + +var _ Conduit = &fakeConduit{} + +func (fc *fakeConduit) done() bool { + if fc.resp == nil { + return true + } + if len(fc.resp.msgs) == 0 { + panic("BUG: not done but no msgs") + } + return false +} + +func (fc *fakeConduit) NextMessage() (SyncMessage, error) { + if len(fc.msgs) != 0 { + m := fc.msgs[0] + fc.msgs = fc.msgs[1:] + return m, nil + } + + return nil, nil +} + +func (fc *fakeConduit) NextItem() 
(Ordered, error) { + if len(fc.items) != 0 { + item := fc.items[0] + fc.items = fc.items[1:] + return item, nil + } + + return nil, nil +} + +func (fc *fakeConduit) sendFingerprint(x, y Ordered, fingerprint any, count int, haveItems bool) { + if fc.resp == nil { + fc.resp = &fakeConduit{} + } + msg := rangeMessage{ + x: x, + y: y, + fp: fingerprint, + count: count, + haveItems: haveItems, + } + fc.resp.msgs = append(fc.resp.msgs, msg) +} + +func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) { + fc.sendFingerprint(x, y, fingerprint, count, false) +} + +func (fc *fakeConduit) SendItems(x, y Ordered, fingerprint any, count int, start, end Iterator) { + fc.sendFingerprint(x, y, fingerprint, count, true) + if start == nil || end == nil { + panic("SendItems with null iterator(s)") + } + it := start + for { + fc.resp.items = append(fc.resp.items, it.Key()) + it = it.Next() + if it.Equal(end) { + break + } + } + if len(fc.resp.items) == 0 { + panic("SendItems with no items") + } +} + type dumbStoreIterator struct { ds *dumbStore n int @@ -287,7 +386,7 @@ func storeItemStr(is ItemStore) string { r := "" for { r += string(it.Key().(sampleID)) - if it == endAt { + if it.Equal(endAt) { return r } it = it.Next() @@ -320,7 +419,7 @@ func forTestStores(t *testing.T, testFunc func(t *testing.T, factory storeFactor } } -func dumpRangeMessages(t *testing.T, msgs []RangeMessage, fmt string, args ...any) { +func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...any) { t.Logf(fmt, args...) 
for _, m := range msgs { t.Logf(" %s", m) @@ -328,21 +427,24 @@ func dumpRangeMessages(t *testing.T, msgs []RangeMessage, fmt string, args ...an } func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds int) { - msgs := []RangeMessage{syncA.Initiate()} + fc := &fakeConduit{} + syncA.Initiate(fc) + require.False(t, fc.done(), "no messages from Initiate") var i int - for i = 0; len(msgs) != 0; i++ { + for i := 0; !fc.done(); i++ { if i == maxRounds { require.FailNow(t, "too many rounds", "didn't reconcile in %d rounds", i) } - // dumpRangeMessages(t, msgs, "A %q -> B %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) - msgs = syncB.Process(msgs) - if msgs != nil { - // dumpRangeMessages(t, msgs, "B %q --> A %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) - msgs = syncA.Process(msgs) + // dumpRangeMessages(t, fc.msgs, "A %q -> B %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) + fc = fc.resp + syncB.Process(fc) + if fc.done() { + break } + fc = fc.resp + // dumpRangeMessages(t, fc.msgs, "B %q --> A %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) + syncA.Process(fc) } - // even with empty sets, there must be an exchange of messages - require.Greater(t, i, 0, "wrong reconc in zero rounds") return i } @@ -455,5 +557,5 @@ func TestRandomSync(t *testing.T) { forTestStores(t, testRandomSync) } -// TBD: random test for MonoidTreeStore +// TBD: include initiate round!!! 
// TBD: use logger for verbose logging (messages) From dfa569d11bd2beaefb76257946d24c9fc749ee7d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 5 Jan 2024 21:57:23 +0400 Subject: [PATCH 03/76] hashsync: get rid of parent links in the tree Parent links make it impossible to implement (semi-)persistent mode --- hashsync/monoid_tree.go | 723 +++++++++++++++++----------------- hashsync/monoid_tree_store.go | 60 ++- hashsync/monoid_tree_test.go | 42 +- hashsync/rangesync.go | 3 +- hashsync/rangesync_test.go | 64 +-- 5 files changed, 448 insertions(+), 444 deletions(-) diff --git a/hashsync/monoid_tree.go b/hashsync/monoid_tree.go index f379fc8f40..0c3af52368 100644 --- a/hashsync/monoid_tree.go +++ b/hashsync/monoid_tree.go @@ -33,9 +33,9 @@ func (fpred FingerprintPredicate) Match(y any) bool { type MonoidTree interface { Fingerprint() any Add(v Ordered) - Min() MonoidTreeNode - Max() MonoidTreeNode - RangeFingerprint(node MonoidTreeNode, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreeNode) + Min() MonoidTreePointer + Max() MonoidTreePointer + RangeFingerprint(ptr MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreePointer) Dump() string } @@ -58,10 +58,11 @@ func MonoidTreeFromSlice[T Ordered](m Monoid, items []T) MonoidTree { return MonoidTreeFromSortedSlice(m, items) } -type MonoidTreeNode interface { +type MonoidTreePointer interface { + Equal(other MonoidTreePointer) bool Key() Ordered - Prev() MonoidTreeNode - Next() MonoidTreeNode + Prev() + Next() } type color uint8 @@ -104,8 +105,123 @@ func (d dir) String() string { } } +const initialParentStackSize = 32 + +type monoidTreePointer struct { + parentStack []*monoidTreeNode + node *monoidTreeNode +} + +var _ MonoidTreePointer = &monoidTreePointer{} + +func (p *monoidTreePointer) clone() *monoidTreePointer { + // TODO: copy node stack + r := &monoidTreePointer{ + parentStack: make([]*monoidTreeNode, 
len(p.parentStack), cap(p.parentStack)), + node: p.node, + } + copy(r.parentStack, p.parentStack) + return r +} + +func (p *monoidTreePointer) parent() { + n := len(p.parentStack) + if n == 0 { + p.node = nil + } else { + n-- + p.node = p.parentStack[n] + p.parentStack = p.parentStack[:n] + } +} + +func (p *monoidTreePointer) left() { + if p.node != nil { + p.parentStack = append(p.parentStack, p.node) + p.node = p.node.left + } +} + +func (p *monoidTreePointer) right() { + if p.node != nil { + p.parentStack = append(p.parentStack, p.node) + p.node = p.node.right + } +} + +func (p *monoidTreePointer) min() { + for { + switch { + case p.node == nil || p.node.left == nil: + return + default: + p.left() + } + } +} + +func (p *monoidTreePointer) max() { + for { + switch { + case p.node == nil || p.node.right == nil: + return + default: + p.right() + } + } +} + +func (p *monoidTreePointer) Equal(other MonoidTreePointer) bool { + if other == nil { + return p.node == nil + } + return p.node == other.(*monoidTreePointer).node +} + +func (p *monoidTreePointer) Prev() { + switch { + case p.node == nil: + case p.node.left != nil: + p.left() + p.max() + default: + oldNode := p.node + for { + p.parent() + if p.node == nil || oldNode != p.node.left { + return + } + oldNode = p.node + } + } +} + +func (p *monoidTreePointer) Next() { + switch { + case p.node == nil: + case p.node.right != nil: + p.right() + p.min() + default: + oldNode := p.node + for { + p.parent() + if p.node == nil || oldNode != p.node.right { + return + } + oldNode = p.node + } + } +} + +func (p *monoidTreePointer) Key() Ordered { + if p.node == nil { + return nil + } + return p.node.key +} + type monoidTreeNode struct { - parent *monoidTreeNode left *monoidTreeNode right *monoidTreeNode key Ordered @@ -141,9 +257,6 @@ func (mn *monoidTreeNode) setChild(dir dir, child *monoidTreeNode) { } else { mn.right = child } - if child != nil { - child.parent = mn - } } func (mn *monoidTreeNode) flip() { @@ -157,73 
+270,6 @@ func (mn *monoidTreeNode) flip() { func (mn *monoidTreeNode) Key() Ordered { return mn.key } -func (mn *monoidTreeNode) minNode() *monoidTreeNode { - if mn.left == nil { - return mn - } - return mn.left.minNode() -} - -func (mn *monoidTreeNode) maxNode() *monoidTreeNode { - if mn.right == nil { - return mn - } - return mn.right.maxNode() -} - -func (mn *monoidTreeNode) prev() *monoidTreeNode { - switch { - case mn == nil: - return nil - case mn.left != nil: - return mn.left.maxNode() - default: - p := mn.parent - for p != nil && mn == p.left { - mn = p - p = p.parent - } - return p - } -} - -func (mn *monoidTreeNode) next() *monoidTreeNode { - switch { - case mn == nil: - return nil - case mn.right != nil: - return mn.right.minNode() - default: - p := mn.parent - for p != nil && mn == p.right { - mn = p - p = p.parent - } - return p - } -} - -func (mn *monoidTreeNode) rmmeStr() string { - if mn == nil { - return "" - } - return fmt.Sprintf("%s", mn.key) -} - -func (mn *monoidTreeNode) Prev() MonoidTreeNode { - if prev := mn.prev(); prev != nil { - return prev - } - return nil -} - -func (mn *monoidTreeNode) Next() MonoidTreeNode { - if next := mn.next(); next != nil { - return next - } - return nil -} - func (mn *monoidTreeNode) dump(w io.Writer, indent int) { indentStr := strings.Repeat(" ", indent) fmt.Fprintf(w, "%skey: %v\n", indentStr, mn.key) @@ -232,9 +278,6 @@ func (mn *monoidTreeNode) dump(w io.Writer, indent int) { if mn.left != nil { fmt.Fprintf(w, "%sleft:\n", indentStr) mn.left.dump(w, indent+1) - if mn.left.parent != mn { - fmt.Fprintf(w, "%sERROR: bad parent on the left\n", indentStr) - } if mn.left.key.Compare(mn.key) >= 0 { fmt.Fprintf(w, "%sERROR: left key >= parent key\n", indentStr) } @@ -242,9 +285,6 @@ func (mn *monoidTreeNode) dump(w io.Writer, indent int) { if mn.right != nil { fmt.Fprintf(w, "%sright:\n", indentStr) mn.right.dump(w, indent+1) - if mn.right.parent != mn { - fmt.Fprintf(w, "%sERROR: bad parent on the right\n", 
indentStr) - } if mn.right.key.Compare(mn.key) <= 0 { fmt.Fprintf(w, "%sERROR: right key <= parent key\n", indentStr) } @@ -258,40 +298,49 @@ func (mn *monoidTreeNode) dumpSubtree() string { } type monoidTree struct { - m Monoid - root *monoidTreeNode - cachedMinNode *monoidTreeNode - cachedMaxNode *monoidTreeNode + m Monoid + root *monoidTreeNode + cachedMinPtr *monoidTreePointer + cachedMaxPtr *monoidTreePointer } func NewMonoidTree(m Monoid) MonoidTree { return &monoidTree{m: m} } -func (mt *monoidTree) Min() MonoidTreeNode { +func (mt *monoidTree) rootPtr() *monoidTreePointer { + return &monoidTreePointer{ + parentStack: make([]*monoidTreeNode, 0, initialParentStackSize), + node: mt.root, + } +} + +func (mt *monoidTree) Min() MonoidTreePointer { if mt.root == nil { return nil } - if mt.cachedMinNode == nil { - mt.cachedMinNode = mt.root.minNode() + if mt.cachedMinPtr == nil { + mt.cachedMinPtr = mt.rootPtr() + mt.cachedMinPtr.min() } - if mt.cachedMinNode == nil { + if mt.cachedMinPtr.node == nil { panic("BUG: no minNode in a non-empty tree") } - return mt.cachedMinNode + return mt.cachedMinPtr.clone() } -func (mt *monoidTree) Max() MonoidTreeNode { +func (mt *monoidTree) Max() MonoidTreePointer { if mt.root == nil { return nil } - if mt.cachedMaxNode == nil { - mt.cachedMaxNode = mt.root.maxNode() + if mt.cachedMaxPtr == nil { + mt.cachedMaxPtr = mt.rootPtr() + mt.cachedMaxPtr.max() } - if mt.cachedMaxNode == nil { + if mt.cachedMaxPtr.node == nil { panic("BUG: no maxNode in a non-empty tree") } - return mt.cachedMaxNode + return mt.cachedMaxPtr.clone() } func (mt *monoidTree) Fingerprint() any { @@ -303,7 +352,6 @@ func (mt *monoidTree) Fingerprint() any { func (mt *monoidTree) newNode(parent *monoidTreeNode, v Ordered) *monoidTreeNode { return &monoidTreeNode{ - parent: parent, key: v, max: v, fingerprint: mt.m.Fingerprint(v), @@ -322,11 +370,9 @@ func (mt *monoidTree) buildFromSortedSlice(parent *monoidTreeNode, s []Ordered) node.left = 
mt.buildFromSortedSlice(node, s[:middle]) node.right = mt.buildFromSortedSlice(node, s[middle+1:]) if node.left != nil { - node.left.parent = node node.fingerprint = mt.m.Op(node.left.fingerprint, node.fingerprint) } if node.right != nil { - node.right.parent = node node.fingerprint = mt.m.Op(node.fingerprint, node.right.fingerprint) node.max = node.right.max } @@ -358,7 +404,6 @@ func (mt *monoidTree) rotate(mn *monoidTreeNode, d dir) *monoidTreeNode { // fmt.Fprintf(os.Stderr, "QQQQQ: rotate %s (child at %s is %s): subtree:\n%s\n", // d, rd, tmp.key, mn.dumpSubtree()) mn.setChild(rd, tmp.child(d)) - tmp.parent = mn.parent tmp.setChild(d, mn) tmp.color = mn.color @@ -388,12 +433,9 @@ func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTree // https://zarif98sjs.github.io/blog/blog/redblacktree/ if mn == nil { mn = mt.newNode(nil, v) - if mt.cachedMinNode != nil && v.Compare(mt.cachedMinNode.key) < 0 { - mt.cachedMinNode = mn - } - if mt.cachedMaxNode != nil && v.Compare(mt.cachedMaxNode.key) > 0 { - mt.cachedMaxNode = mn - } + // if the tree is being modified, cached min/max ptrs are no longer valid + mt.cachedMinPtr = nil + mt.cachedMaxPtr = nil return mn } c := v.Compare(mn.key) @@ -449,135 +491,116 @@ func (mt *monoidTree) insertFixup(mn *monoidTreeNode, d dir, updateFP bool) (*mo return mn, updateFP } -func (mt *monoidTree) findGTENode(mn *monoidTreeNode, x Ordered) *monoidTreeNode { - switch { - case mn == nil: - return nil - case x.Compare(mn.key) == 0: - // Exact match - return mn - case x.Compare(mn.max) > 0: - // All of this subtree is below v, maybe we can have - // some luck with the parent node - return mt.findGTENode(mn.parent, x) - case x.Compare(mn.key) >= 0: - // We're still below x (or at x, but allowEqual is - // false), but given that we checked Max and saw that - // this subtree has some keys that are greater than - // or equal to x, we can find them on the right - if mn.right == nil { - // mn.Max lied to us - 
panic("BUG: MonoidTreeNode: x > mn.Max but no right branch") - } - // Avoid endless recursion in case of a bug - if x.Compare(mn.right.max) > 0 { - panic("BUG: MonoidTreeNode: inconsistent Max on the right branch") - } - return mt.findGTENode(mn.right, x) - case mn.left == nil || x.Compare(mn.left.max) > 0: - // The current node's key is greater than x and the - // left branch is either empty or fully below x, so - // the current node is what we were looking for - return mn - default: - // Some keys on the left branch are greater or equal - // than x accordingto mn.Left.Max - r := mt.findGTENode(mn.left, x) - if r == nil { - panic("BUG: MonoidTreeNode: inconsistent Max on the left branch") - } - return r - } -} - -func (mt *monoidTree) invRangeFingerprint(mn *monoidTreeNode, x, y Ordered, stop FingerprintPredicate) (any, *monoidTreeNode) { - // QQQQQ: rename: rollover range - next := mn - minNode := mn.minNode() - - var acc any - var stopped bool - rightStartNode := mt.findGTENode(mn, x) - if rightStartNode != nil { - acc, next, stopped = mt.aggregateUntil(rightStartNode, acc, x, UpperBound{}, stop) - if stopped { - return acc, next +func (mt *monoidTree) findGTENode(ptr *monoidTreePointer, x Ordered) bool { + for { + switch { + case ptr.node == nil: + return false + case x.Compare(ptr.node.key) == 0: + // Exact match + return true + case x.Compare(ptr.node.max) > 0: + // All of this subtree is below v, maybe we can have + // some luck with the parent node + ptr.parent() + mt.findGTENode(ptr, x) + case x.Compare(ptr.node.key) >= 0: + // We're still below x (or at x, but allowEqual is + // false), but given that we checked Max and saw that + // this subtree has some keys that are greater than + // or equal to x, we can find them on the right + if ptr.node.right == nil { + // mn.Max lied to us + panic("BUG: MonoidTreeNode: x > mn.Max but no right branch") + } + // Avoid endless recursion in case of a bug + if x.Compare(ptr.node.right.max) > 0 { + panic("BUG: 
MonoidTreeNode: inconsistent Max on the right branch") + } + ptr.right() + case ptr.node.left == nil || x.Compare(ptr.node.left.max) > 0: + // The current node's key is greater than x and the + // left branch is either empty or fully below x, so + // the current node is what we were looking for + return true + default: + // Some keys on the left branch are greater or equal + // than x accordingto mn.Left.Max + ptr.left() } - } else { - acc = mt.m.Identity() - } - - if y.Compare(minNode.key) > 0 { - acc, next, _ = mt.aggregateUntil(minNode, acc, LowerBound{}, y, stop) } - - return acc, next } -func (mt *monoidTree) rangeFingerprint(node MonoidTreeNode, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode *monoidTreeNode) { +func (mt *monoidTree) rangeFingerprint(preceding MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startPtr, endPtr *monoidTreePointer) { if mt.root == nil { return mt.m.Identity(), nil, nil } - if node == nil { - node = mt.root + var ptr *monoidTreePointer + if preceding == nil { + ptr = mt.rootPtr() + } else { + ptr = preceding.(*monoidTreePointer) } - mn := node.(*monoidTreeNode) - minNode := mt.root.minNode() + minPtr := mt.Min().(*monoidTreePointer) acc := mt.m.Identity() - startNode = mt.findGTENode(mn, start) + haveGTE := mt.findGTENode(ptr, start) + startPtr = ptr.clone() switch { case start.Compare(end) >= 0: // rollover range, which includes the case start == end // this includes 2 subranges: // [start, max_element] and [min_element, end) var stopped bool - if node != nil { - acc, endNode, stopped = mt.aggregateUntil(startNode, acc, start, UpperBound{}, stop) + if haveGTE { + acc, stopped = mt.aggregateUntil(ptr, acc, start, UpperBound{}, stop) } - if !stopped && end.Compare(minNode.key) > 0 { - acc, endNode, _ = mt.aggregateUntil(minNode, acc, LowerBound{}, end, stop) + if !stopped && end.Compare(minPtr.Key()) > 0 { + ptr = minPtr.clone() + acc, _ = mt.aggregateUntil(ptr, acc, 
LowerBound{}, end, stop) } - case node != nil: + case haveGTE: // normal range, that is, start < end - acc, endNode, _ = mt.aggregateUntil(startNode, mt.m.Identity(), start, end, stop) + acc, _ = mt.aggregateUntil(ptr, mt.m.Identity(), start, end, stop) } - if startNode == nil { - startNode = minNode + if startPtr.node == nil { + startPtr = minPtr.clone() } - if endNode == nil { - endNode = minNode + if ptr.node == nil { + ptr = minPtr.clone() } - return acc, startNode, endNode + return acc, startPtr, ptr } -func (mt *monoidTree) RangeFingerprint(node MonoidTreeNode, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreeNode) { - fp, stn, endn := mt.rangeFingerprint(node, start, end, stop) +func (mt *monoidTree) RangeFingerprint(ptr MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreePointer) { + fp, startPtr, endPtr := mt.rangeFingerprint(ptr, start, end, stop) switch { - case stn == nil && endn == nil: - // avoid wrapping nil in MonoidTreeNode interface + case startPtr == nil && endPtr == nil: + // avoid wrapping nil in MonoidTreePointer interface return fp, nil, nil - case stn == nil || endn == nil: + case startPtr == nil || endPtr == nil: panic("BUG: can't have nil node just on one end") default: - return fp, stn, endn + return fp, startPtr, endPtr } } -func (mt *monoidTree) aggregateUntil(mn *monoidTreeNode, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, node *monoidTreeNode, stopped bool) { - acc, node, stopped = mt.aggregateUp(mn, acc, start, end, stop) - if node == nil || end.Compare(node.key) <= 0 || stopped { - return acc, node, stopped +func (mt *monoidTree) aggregateUntil(ptr *monoidTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { + acc, stopped = mt.aggregateUp(ptr, acc, start, end, stop) + if ptr.node == nil || end.Compare(ptr.node.key) <= 0 || stopped { + return acc, stopped } - f := mt.m.Op(acc, 
mt.m.Fingerprint(node.key)) + // fmt.Fprintf(os.Stderr, "QQQQQ: from aggregateUp: acc %q; ptr.node %q\n", acc, ptr.node.key) + f := mt.m.Op(acc, mt.m.Fingerprint(ptr.node.key)) if stop.Match(f) { - return acc, node, true + return acc, true } - return mt.aggregateDown(node.right, f, end, stop) + ptr.right() + return mt.aggregateDown(ptr, f, end, stop) } // aggregateUp ascends from the left (lower) end of the range towards the LCA @@ -594,50 +617,49 @@ func (mt *monoidTree) aggregateUntil(mn *monoidTreeNode, acc any, start, end Ord // If stop function is passed, we find the node on which it returns true // for the fingerprint accumulated between start and that node, if the target // node is somewhere to the left from the LCA. -func (mt *monoidTree) aggregateUp(mn *monoidTreeNode, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, node *monoidTreeNode, stopped bool) { - switch { - case mn == nil: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: null node\n") - return acc, nil, false - case stop.Match(acc): - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop: node %v acc %v\n", mn.key, acc) - return acc, mn.prev(), true - case end.Compare(mn.max) <= 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: LCA: node %v acc %v\n", mn.key, acc) - // This node is a the LCA, the starting point for AggregateDown - return acc, mn, false - case start.Compare(mn.key) <= 0: - // This node is within the target range - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: in-range node %v acc %v\n", mn.key, acc) - f := mt.m.Op(acc, mt.m.Fingerprint(mn.key)) - if stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop at the own node %v acc %v\n", mn.key, acc) - return acc, mn, true - } - f1 := mt.m.Op(f, mt.safeFingerprint(mn.right)) - if stop.Match(f1) { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree matches node %v acc %v f1 %v\n", mn.key, acc, f1) - // The target node is somewhere in the right subtree - if mn.right == nil { - panic("BUG: nil 
right child with non-identity fingerprint") +func (mt *monoidTree) aggregateUp(ptr *monoidTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { + for { + switch { + case ptr.node == nil: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: null node\n") + return acc, false + case stop.Match(acc): + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop: node %v acc %v\n", mn.key, acc) + ptr.Prev() + return acc, true + case end.Compare(ptr.node.max) <= 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: LCA: node %v acc %v\n", mn.key, acc) + // This node is a the LCA, the starting point for AggregateDown + return acc, false + case start.Compare(ptr.node.key) <= 0: + // This node is within the target range + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: in-range node %v acc %v\n", mn.key, acc) + f := mt.m.Op(acc, mt.m.Fingerprint(ptr.node.key)) + if stop.Match(f) { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop at the own node %v acc %v\n", mn.key, acc) + return acc, true } - acc, node := mt.boundedAggregate(mn.right, f, stop) - if node == nil { - panic("BUG: aggregateUp: bad subtree fingerprint on the right branch") + f1 := mt.m.Op(f, mt.safeFingerprint(ptr.node.right)) + if stop.Match(f1) { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree matches node %v acc %v f1 %v\n", mn.key, acc, f1) + // The target node is somewhere in the right subtree + if ptr.node.right == nil { + panic("BUG: nil right child with non-identity fingerprint") + } + ptr.right() + acc := mt.boundedAggregate(ptr, f, stop) + if ptr.node == nil { + panic("BUG: aggregateUp: bad subtree fingerprint on the right branch") + } + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree: node %v acc %v\n", node.key, acc) + return acc, true + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: no right subtree match: node %v acc %v f1 %v\n", mn.key, acc, f1) + acc = f1 } - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree: node %v 
acc %v\n", node.key, acc) - return acc, node, true - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: no right subtree match: node %v acc %v f1 %v\n", mn.key, acc, f1) - acc = f1 } + ptr.parent() } - if mn.parent == nil { - // No need for AggregateDown as we've covered the entire - // [start, end) range - return acc, nil, false - } - return mt.aggregateUp(mn.parent, acc, start, end, stop) } // aggregateDown descends from the LCA (lowest common ancestor) of nodes within @@ -646,136 +668,119 @@ func (mt *monoidTree) aggregateUp(mn *monoidTreeNode, acc any, start, end Ordere // aggregation using their saved fingerprint. // If stop function is passed, we find the node on which it returns true // for the fingerprint accumulated between start and that node -func (mt *monoidTree) aggregateDown(mn *monoidTreeNode, acc any, end Ordered, stop FingerprintPredicate) (fp any, node *monoidTreeNode, stopped bool) { - switch { - case mn == nil: - // fmt.Fprintf(os.Stderr, "QQQQQ: mn == nil\n") - return acc, nil, false - case stop.Match(acc): - // fmt.Fprintf(os.Stderr, "QQQQQ: stop on node\n") - return acc, mn.prev(), true - case end.Compare(mn.key) > 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: within the range\n") - // We're within the range but there also may be nodes - // within the range to the right. 
The left branch is - // fully within the range - f := mt.m.Op(acc, mt.safeFingerprint(mn.left)) - if stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: left subtree covers it\n") - // The target node is somewhere in the left subtree - if mn.left == nil { - panic("BUG: aggregateDown: nil left child with non-identity fingerprint") +func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { + for { + switch { + case ptr.node == nil: + // fmt.Fprintf(os.Stderr, "QQQQQ: mn == nil\n") + return acc, false + case stop.Match(acc): + // fmt.Fprintf(os.Stderr, "QQQQQ: stop on node\n") + ptr.Prev() + return acc, true + case end.Compare(ptr.node.key) > 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: within the range\n") + // We're within the range but there also may be nodes + // within the range to the right. The left branch is + // fully within the range + f := mt.m.Op(acc, mt.safeFingerprint(ptr.node.left)) + if stop.Match(f) { + // fmt.Fprintf(os.Stderr, "QQQQQ: left subtree covers it\n") + // The target node is somewhere in the left subtree + if ptr.node.left == nil { + panic("BUG: aggregateDown: nil left child with non-identity fingerprint") + } + ptr.left() + return mt.boundedAggregate(ptr, acc, stop), true } - acc, node := mt.boundedAggregate(mn.left, acc, stop) - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: returned acc %v node %#v\n", acc, node) - if node == nil { - panic("BUG: aggregateDown: bad subtree fingerprint on the left branch") - } - return acc, node, true - } - f1 := mt.m.Op(f, mt.m.Fingerprint(mn.key)) - if stop.Match(f1) { - // fmt.Fprintf(os.Stderr, "QQQQQ: stop at the node, prev %#v\n", node.prev()) - return f, mn, true - } else { - acc = f1 - } - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown on the right\n") - return mt.aggregateDown(mn.right, acc, end, stop) - case mn.left == nil || end.Compare(mn.left.max) > 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: node covers the range\n") - // 
Found the rightmost bounding node - f := mt.m.Op(acc, mt.safeFingerprint(mn.left)) - if stop.Match(f) { - // The target node is somewhere in the left subtree - if mn.left == nil { - panic("BUG: aggregateDown: nil left child with non-identity fingerprint") + f1 := mt.m.Op(f, mt.m.Fingerprint(ptr.node.key)) + if stop.Match(f1) { + // fmt.Fprintf(os.Stderr, "QQQQQ: stop at the node, prev %#v\n", node.prev()) + return f, true + } else { + acc = f1 } - acc, node := mt.boundedAggregate(mn.left, acc, stop) - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate(2): returned acc %v node %#v\n", acc, node) - if node == nil { - panic("BUG: aggregateDown: bad subtree fingerprint on the left branch") + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown on the right\n") + ptr.right() + case ptr.node.left == nil || end.Compare(ptr.node.left.max) > 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: node covers the range\n") + // Found the rightmost bounding node + f := mt.m.Op(acc, mt.safeFingerprint(ptr.node.left)) + if stop.Match(f) { + // The target node is somewhere in the left subtree + if ptr.node.left == nil { + panic("BUG: aggregateDown: nil left child with non-identity fingerprint") + } + // XXXXX fixme + ptr.left() + return mt.boundedAggregate(ptr, acc, stop), true } - return acc, node, true + return f, false + default: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown: going further down\n") + // We're too far to the right, outside the range + ptr.left() } - return f, mn, false - default: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown: going further down\n") - // We're too far to the right, outside the range - return mt.aggregateDown(mn.left, acc, end, stop) } } -func (mt *monoidTree) boundedAggregate(mn *monoidTreeNode, acc any, stop FingerprintPredicate) (any, *monoidTreeNode) { - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: node %v, acc %v\n", mn.key, acc) - if mn == nil { - return acc, nil - } +func (mt *monoidTree) boundedAggregate(ptr *monoidTreePointer, acc any, stop 
FingerprintPredicate) any { + for { + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: node %v, acc %v\n", mn.key, acc) + if ptr.node == nil { + return acc + } - // If we don't need to stop, or if the stop point is somewhere after - // this subtree, we can just use the pre-calculated subtree fingerprint - if f := mt.m.Op(acc, mn.fingerprint); !stop.Match(f) { - return f, nil - } + // If we don't need to stop, or if the stop point is somewhere after + // this subtree, we can just use the pre-calculated subtree fingerprint + if f := mt.m.Op(acc, ptr.node.fingerprint); !stop.Match(f) { + return f + } - // This function is not supposed to be called with acc already matching - // the stop condition - if stop(acc) { - panic("BUG: boundedAggregate: initial fingerprint is matched before the first node") - } + // This function is not supposed to be called with acc already matching + // the stop condition + if stop.Match(acc) { + panic("BUG: boundedAggregate: initial fingerprint is matched before the first node") + } - if mn.left != nil { - // See if we can skip recursion on the left branch - f := mt.m.Op(acc, mn.left.fingerprint) - if !stop(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and no-stop %v, f %v, left fingerprint %v\n", mn.key, f, mn.Left.Fingerprint) - acc = f - } else { - // The target node must be contained in the left subtree - var node *monoidTreeNode - acc, node = mt.boundedAggregate(mn.left, acc, stop) - if node == nil { - panic("BUG: boundedAggregate: bad subtree fingerprint on the left branch") + if ptr.node.left != nil { + // See if we can skip recursion on the left branch + f := mt.m.Op(acc, ptr.node.left.fingerprint) + if !stop.Match(f) { + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and no-stop %v, f %v, left fingerprint %v\n", mn.key, f, mn.Left.Fingerprint) + acc = f + } else { + // The target node must be contained in the left subtree + ptr.left() + // fmt.Fprintf(os.Stderr, "QQQQQ: 
boundedAggregate: mn Left non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) + continue } - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) - return acc, node } - } - f := mt.m.Op(acc, mt.m.Fingerprint(mn.key)) - - switch { - case stop(f): - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: stop at this node %v, fp %v\n", mn.key, acc) - return acc, mn - case mn.right != nil: - f1 := mt.m.Op(f, mn.right.fingerprint) - if !stop(f1) { - // The right branch is still below the target fingerprint - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and no-stop %v, acc %v\n", mn.key, acc) - acc = f1 - } else { - // The target node must be contained in the right subtree - var node *monoidTreeNode - acc, node := mt.boundedAggregate(mn.right, f, stop) - if node == nil { - panic("BUG: boundedAggregate: bad subtree fingerprint on the right branch") + f := mt.m.Op(acc, mt.m.Fingerprint(ptr.node.key)) + if stop.Match(f) { + return acc + } + acc = f + + if ptr.node.right != nil { + f1 := mt.m.Op(f, ptr.node.right.fingerprint) + if !stop.Match(f1) { + // The right branch is still below the target fingerprint + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and no-stop %v, acc %v\n", mn.key, acc) + acc = f1 + } else { + // The target node must be contained in the right subtree + acc = f + ptr.right() + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) + continue } - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) - return acc, node } + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", mn.key, acc) + // QQQQQ: ZXXXXXX: return acc, nil !!!! 
+ return acc } - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", mn.key, acc) - // QQQQQ: ZXXXXXX: return acc, nil !!!! - return f, nil -} - -func (mt *monoidTree) Next(node MonoidTreeNode) MonoidTreeNode { - next := node.(*monoidTreeNode).next() - if next == nil { - return nil - } - return next } func (mt *monoidTree) Dump() string { @@ -792,12 +797,8 @@ func (mt *monoidTree) Dump() string { // too, in this case (Insert returns a new tree => no problem with thread safety // or cached min/max) -// Persistent tree: -// In 'derived' mode, along with the color, use 2 more bits: -// * whether this node is new -// * whether this node has new descendants -// Note that any combination possible (a new node MAY NOT have any new descendants) -// Derived trees are created with Derive(). The persistent mode may only be used -// for derived trees (need to see if this is worth it). -// With these bits, it is fairly easy to grab the newly added items w/o -// using the local sync algorithm +// TODO: rename MonoidTreeNode to just Node, MonoidTree to SyncTree +// TODO: use sync.Pool for node alloc +// see also: +// https://www.akshaydeo.com/blog/2017/12/23/How-did-I-improve-latency-by-700-percent-using-syncPool/ +// so may need refcounting diff --git a/hashsync/monoid_tree_store.go b/hashsync/monoid_tree_store.go index 7ad948ccf3..3d0f7af21b 100644 --- a/hashsync/monoid_tree_store.go +++ b/hashsync/monoid_tree_store.go @@ -1,38 +1,28 @@ package hashsync type monoidTreeIterator struct { - mt MonoidTree - node MonoidTreeNode + mt MonoidTree + ptr MonoidTreePointer } -var _ Iterator = monoidTreeIterator{} +var _ Iterator = &monoidTreeIterator{} -func (it monoidTreeIterator) Equal(other Iterator) bool { - o := other.(monoidTreeIterator) +func (it *monoidTreeIterator) Equal(other Iterator) bool { + o := other.(*monoidTreeIterator) if it.mt != o.mt { panic("comparing iterators from different MonoidTreeStore") } - return it.node == o.node + return 
it.ptr.Equal(o.ptr) } -func (it monoidTreeIterator) Key() Ordered { - return it.node.Key() +func (it *monoidTreeIterator) Key() Ordered { + return it.ptr.Key() } -func (it monoidTreeIterator) Next() Iterator { - next := it.node.Next() - if next == nil { - next = it.mt.Min() - } - if next == nil { - return nil - } - if next.(*monoidTreeNode) == nil { - panic("QQQQQ: wrapped nil in Next") - } - return monoidTreeIterator{ - mt: it.mt, - node: next, +func (it *monoidTreeIterator) Next() { + it.ptr.Next() + if it.ptr.Key() == nil { + it.ptr = it.mt.Min() } } @@ -54,43 +44,39 @@ func (mts *MonoidTreeStore) Add(v Ordered) { mts.mt.Add(v) } -func (mts *MonoidTreeStore) iter(node MonoidTreeNode) Iterator { - if node == nil { +func (mts *MonoidTreeStore) iter(ptr MonoidTreePointer) Iterator { + if ptr == nil { return nil } - if node.(*monoidTreeNode) == nil { - panic("QQQQQ: wrapped nil") - } - return monoidTreeIterator{ - mt: mts.mt, - node: node, + return &monoidTreeIterator{ + mt: mts.mt, + ptr: ptr, } } // GetRangeInfo implements ItemStore. 
func (mts *MonoidTreeStore) GetRangeInfo(preceding Iterator, x Ordered, y Ordered, count int) RangeInfo { var stop FingerprintPredicate - var node MonoidTreeNode + var node MonoidTreePointer if preceding != nil { - p := preceding.(monoidTreeIterator) + p := preceding.(*monoidTreeIterator) if p.mt != mts.mt { panic("GetRangeInfo: preceding iterator from a wrong MonoidTreeStore") } - node = p.node + node = p.ptr } if count >= 0 { stop = func(fp any) bool { return CombinedSecond[int](fp) > count } } - fp, startNode, endNode := mts.mt.RangeFingerprint(node, x, y, stop) - // fmt.Fprintf(os.Stderr, "QQQQQ: fp %v, startNode %#v, endNode %#v\n", fp, startNode, endNode) + fp, startPtr, endPtr := mts.mt.RangeFingerprint(node, x, y, stop) cfp := fp.(CombinedFingerprint) return RangeInfo{ Fingerprint: cfp.First, Count: cfp.Second.(int), - Start: mts.iter(startNode), - End: mts.iter(endNode), + Start: mts.iter(startPtr), + End: mts.iter(endPtr), } } diff --git a/hashsync/monoid_tree_test.go b/hashsync/monoid_tree_test.go index f4bf2370bf..98e7a3cb03 100644 --- a/hashsync/monoid_tree_test.go +++ b/hashsync/monoid_tree_test.go @@ -272,7 +272,7 @@ func TestAscendingRanges(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { var fps []string - var node MonoidTreeNode + var node MonoidTreePointer for n, rng := range tc.ranges { x := sampleID(rng[0]) y := sampleID(rng[1]) @@ -293,34 +293,30 @@ func TestAscendingRanges(t *testing.T) { } func verifyBinaryTree(t *testing.T, mn *monoidTreeNode) { - if mn.parent != nil && mn != mn.parent.left && mn != mn.parent.right { - require.Fail(t, "node is an 'unknown' child") - } - if mn.left != nil { - require.Equal(t, mn, mn.left.parent, "bad parent node on the left branch") require.Negative(t, mn.left.key.Compare(mn.key)) - leftMaxNode := mn.left.maxNode() - require.Negative(t, leftMaxNode.key.Compare(mn.key)) + // not a "real" pointer (no parent stack), just to get max + leftMax := &monoidTreePointer{node: mn.left} + leftMax.max() + 
require.Negative(t, leftMax.Key().Compare(mn.key)) verifyBinaryTree(t, mn.left) } if mn.right != nil { - require.Equal(t, mn, mn.right.parent, "bad parent node on the right branch") require.Positive(t, mn.right.key.Compare(mn.key)) - rightMinNode := mn.right.minNode() - require.Positive(t, rightMinNode.key.Compare(mn.key)) + // not a "real" pointer (no parent stack), just to get min + rightMin := &monoidTreePointer{node: mn.right} + rightMin.min() + require.Positive(t, rightMin.Key().Compare(mn.key)) verifyBinaryTree(t, mn.right) } } -func verifyRedBlack(t *testing.T, mn *monoidTreeNode, blackDepth int) int { +func verifyRedBlackNode(t *testing.T, mn *monoidTreeNode, blackDepth int) int { if mn == nil { return blackDepth + 1 } if mn.color == red { - require.NotNil(t, mn.parent, "root node must be black") - require.Equal(t, black, mn.parent.color, "parent of a red node is red") if mn.left != nil { require.Equal(t, black, mn.left.color, "left child of a red node is red") } @@ -330,12 +326,20 @@ func verifyRedBlack(t *testing.T, mn *monoidTreeNode, blackDepth int) int { } else { blackDepth++ } - bdLeft := verifyRedBlack(t, mn.left, blackDepth) - bdRight := verifyRedBlack(t, mn.right, blackDepth) + bdLeft := verifyRedBlackNode(t, mn.left, blackDepth) + bdRight := verifyRedBlackNode(t, mn.right, blackDepth) require.Equal(t, bdLeft, bdRight, "subtree black depth for node %s", mn.key) return bdLeft } +func verifyRedBlack(t *testing.T, mt *monoidTree) { + if mt.root == nil { + return + } + require.Equal(t, black, mt.root.color, "root node must be black") + verifyRedBlackNode(t, mt.root, 0) +} + func TestRedBlackTreeInsert(t *testing.T) { for i := 0; i < 1000; i++ { tree := NewMonoidTree(sampleCountMonoid()) @@ -360,12 +364,12 @@ func TestRedBlackTreeInsert(t *testing.T) { // t.Logf("QQQQQ: tree:\n%s", tree.Dump()) root := tree.(*monoidTree).root verifyBinaryTree(t, root) - verifyRedBlack(t, root, 0) - for node := tree.Min(); node != nil; node = node.Next() { + 
verifyRedBlack(t, tree.(*monoidTree)) + for ptr := tree.Min(); ptr.Key() != nil; ptr.Next() { // avoid endless loop due to bugs in the tree impl require.Less(t, n, len(items)*2, "got much more items than needed: %q -- %q", actualItems, shuffled) n++ - actualItems = append(actualItems, node.Key().(sampleID)[0]) + actualItems = append(actualItems, ptr.Key().(sampleID)[0]) } require.Equal(t, items, actualItems) diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 274cbee669..c6c724aeac 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -45,7 +45,7 @@ type Iterator interface { Key() Ordered // Next returns an iterator pointing to the next key or nil // if this key is the last one in the store - Next() Iterator + Next() } type RangeInfo struct { @@ -213,3 +213,4 @@ func (rsr *RangeSetReconciler) Process(c Conduit) error { // TBD: limit the number of rounds (outside RangeSetReconciler) // TBD: process ascending ranges properly // TBD: bounded reconcile +// TBD: limit max N of received unconfirmed items diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index da9ffd3b76..24064880ce 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -97,8 +97,11 @@ func (fc *fakeConduit) SendItems(x, y Ordered, fingerprint any, count int, start } it := start for { + if it.Key() == nil { + panic("fakeConduit.SendItems: went got to the end of the tree") + } fc.resp.items = append(fc.resp.items, it.Key()) - it = it.Next() + it.Next() if it.Equal(end) { break } @@ -113,24 +116,23 @@ type dumbStoreIterator struct { n int } -var _ Iterator = dumbStoreIterator{} +var _ Iterator = &dumbStoreIterator{} -func (it dumbStoreIterator) Equal(other Iterator) bool { - o := other.(dumbStoreIterator) +func (it *dumbStoreIterator) Equal(other Iterator) bool { + o := other.(*dumbStoreIterator) if it.ds != o.ds { panic("comparing iterators from different dumbStores") } return it.n == o.n } -func (it dumbStoreIterator) Key() Ordered { +func 
(it *dumbStoreIterator) Key() Ordered { return it.ds.items[it.n] } -func (it dumbStoreIterator) Next() Iterator { - return dumbStoreIterator{ - ds: it.ds, - n: (it.n + 1) % len(it.ds.items), +func (it *dumbStoreIterator) Next() { + if len(it.ds.items) != 0 { + it.n = (it.n + 1) % len(it.ds.items) } } @@ -164,7 +166,7 @@ func (ds *dumbStore) iter(n int) Iterator { if n == -1 || n == len(ds.items) { return nil } - return dumbStoreIterator{ds: ds, n: n} + return &dumbStoreIterator{ds: ds, n: n} } func (ds *dumbStore) last() sampleID { @@ -208,7 +210,7 @@ func (ds *dumbStore) Min() Iterator { if len(ds.items) == 0 { return nil } - return dumbStoreIterator{ + return &dumbStoreIterator{ ds: ds, n: 0, } @@ -218,7 +220,7 @@ func (ds *dumbStore) Max() Iterator { if len(ds.items) == 0 { return nil } - return dumbStoreIterator{ + return &dumbStoreIterator{ ds: ds, n: len(ds.items) - 1, } @@ -236,7 +238,9 @@ func (it verifiedStoreIterator) Equal(other Iterator) bool { o := other.(verifiedStoreIterator) eq1 := it.knownGood.Equal(o.knownGood) eq2 := it.it.Equal(o.it) - require.Equal(it.t, eq1, eq2, "iterators equal") + require.Equal(it.t, eq1, eq2, "iterators equal -- keys <%v> <%v> / <%v> <%v>", + it.knownGood.Key(), it.it.Key(), + o.knownGood.Key(), o.it.Key()) require.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys of equal iterators") return eq2 } @@ -248,15 +252,10 @@ func (it verifiedStoreIterator) Key() Ordered { return k2 } -func (it verifiedStoreIterator) Next() Iterator { - next1 := it.knownGood.Next() - next2 := it.it.Next() - require.Equal(it.t, next1.Key(), next2.Key(), "keys for Next()") - return verifiedStoreIterator{ - t: it.t, - knownGood: next1, - it: next2, - } +func (it verifiedStoreIterator) Next() { + it.knownGood.Next() + it.it.Next() + require.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys for Next()") } type verifiedStore struct { @@ -389,7 +388,7 @@ func storeItemStr(is ItemStore) string { if it.Equal(endAt) { return r } - it = it.Next() + 
it.Next() } } @@ -445,7 +444,7 @@ func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRo // dumpRangeMessages(t, fc.msgs, "B %q --> A %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) syncA.Process(fc) } - return i + return i + 1 } func testRangeSync(t *testing.T, storeFactory storeFactory) { @@ -490,6 +489,13 @@ func testRangeSync(t *testing.T, storeFactory storeFactory) { final: "abcdefghijklmnopqr", maxRounds: [4]int{4, 4, 4, 3}, }, + { + name: "sync against 1-element set", + a: "bcd", + b: "a", + final: "abcd", + maxRounds: [4]int{3, 2, 2, 1}, + }, } { t.Run(tc.name, func(t *testing.T) { for n, maxSendRange := range []int{1, 2, 3, 4} { @@ -514,20 +520,26 @@ func TestRangeSync(t *testing.T) { } func testRandomSync(t *testing.T, storeFactory storeFactory) { + var bytesA, bytesB []byte + defer func() { + if t.Failed() { + t.Logf("Random sync failed: %q <-> %q", bytesA, bytesB) + } + }() for i := 0; i < 1000; i++ { var chars []byte for c := byte(33); c < 127; c++ { chars = append(chars, c) } - bytesA := append([]byte(nil), chars...) + bytesA = append([]byte(nil), chars...) rand.Shuffle(len(bytesA), func(i, j int) { bytesA[i], bytesA[j] = bytesA[j], bytesA[i] }) bytesA = bytesA[:rand.Intn(len(bytesA))] storeA := makeStore(t, storeFactory, string(bytesA)) - bytesB := append([]byte(nil), chars...) + bytesB = append([]byte(nil), chars...) 
rand.Shuffle(len(bytesB), func(i, j int) { bytesB[i], bytesB[j] = bytesB[j], bytesB[i] }) From 34f3b739d14a336a6a439fc8eb4cd3452142f02e Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 6 Jan 2024 01:10:56 +0400 Subject: [PATCH 04/76] hashsync: implement persistent tree --- hashsync/monoid_tree.go | 185 +++++++++++++++++++++++------------ hashsync/monoid_tree_test.go | 77 ++++++++++----- 2 files changed, 178 insertions(+), 84 deletions(-) diff --git a/hashsync/monoid_tree.go b/hashsync/monoid_tree.go index 0c3af52368..1f8f920a85 100644 --- a/hashsync/monoid_tree.go +++ b/hashsync/monoid_tree.go @@ -4,6 +4,7 @@ package hashsync import ( "fmt" "io" + "reflect" "slices" "strings" ) @@ -31,6 +32,7 @@ func (fpred FingerprintPredicate) Match(y any) bool { } type MonoidTree interface { + Copy() MonoidTree Fingerprint() any Add(v Ordered) Min() MonoidTreePointer @@ -65,26 +67,19 @@ type MonoidTreePointer interface { Next() } -type color uint8 +type flags uint8 const ( - red color = 0 - black color = 1 + // flagBlack indicates a black node. If it is not set, the + // node is red, which is the default for newly created nodes + flagBlack flags = 1 + // flagCloned indicates a node that is only present in this + // tree and not in any of its copies, thus permitting + // modification of this node without cloning it. When the tree + // is copied, flagCloned is cleared on all of its nodes. 
+ flagCloned flags = 2 ) -func (c color) flip() color { return c ^ 1 } - -func (c color) String() string { - switch c { - case red: - return "red" - case black: - return "black" - default: - return fmt.Sprintf("", c) - } -} - type dir uint8 const ( @@ -227,15 +222,15 @@ type monoidTreeNode struct { key Ordered max Ordered fingerprint any - color color + flags flags } func (mn *monoidTreeNode) red() bool { - return mn != nil && mn.color == red + return mn != nil && (mn.flags&flagBlack) == 0 } func (mn *monoidTreeNode) black() bool { - return mn == nil || mn.color == black + return mn == nil || (mn.flags&flagBlack) != 0 } func (mn *monoidTreeNode) child(dir dir) *monoidTreeNode { @@ -248,26 +243,6 @@ func (mn *monoidTreeNode) child(dir dir) *monoidTreeNode { return mn.right } -func (mn *monoidTreeNode) setChild(dir dir, child *monoidTreeNode) { - if mn == nil { - panic("setChild for a nil node") - } - if dir == left { - mn.left = child - } else { - mn.right = child - } -} - -func (mn *monoidTreeNode) flip() { - if mn.left == nil || mn.right == nil { - panic("can't flip color with one or more nil children") - } - mn.color = mn.color.flip() - mn.left.color = mn.left.color.flip() - mn.right.color = mn.right.color.flip() -} - func (mn *monoidTreeNode) Key() Ordered { return mn.key } func (mn *monoidTreeNode) dump(w io.Writer, indent int) { @@ -275,6 +250,11 @@ func (mn *monoidTreeNode) dump(w io.Writer, indent int) { fmt.Fprintf(w, "%skey: %v\n", indentStr, mn.key) fmt.Fprintf(w, "%smax: %v\n", indentStr, mn.max) fmt.Fprintf(w, "%sfp: %v\n", indentStr, mn.fingerprint) + color := "red" + if mn.black() { + color = "black" + } + fmt.Fprintf(w, "%scolor: %v\n", indentStr, color) if mn.left != nil { fmt.Fprintf(w, "%sleft:\n", indentStr) mn.left.dump(w, indent+1) @@ -297,6 +277,19 @@ func (mn *monoidTreeNode) dumpSubtree() string { return sb.String() } +// cleanNodes removed flagCloned from all of the nodes in the subtree, +// so that it can be used in further cloned trees. 
+// A non-cloned node cannot have any cloned children, so the function +// stops the recursion at any non-cloned node. +func (mn *monoidTreeNode) cleanCloned() { + if mn == nil || mn.flags&flagCloned == 0 { + return + } + mn.flags &^= flagCloned + mn.left.cleanCloned() + mn.right.cleanCloned() +} + type monoidTree struct { m Monoid root *monoidTreeNode @@ -308,6 +301,19 @@ func NewMonoidTree(m Monoid) MonoidTree { return &monoidTree{m: m} } +func (mt *monoidTree) Copy() MonoidTree { + // Clean flagCloned from any nodes created specifically + // for this subtree. This will mean they will have to be + // re-cloned if they need to be changed again. + mt.root.cleanCloned() + // Don't reuse cachedMinPtr / cachedMaxPtr for the cloned + // tree to be on the safe side + return &monoidTree{ + m: mt.m, + root: mt.root, + } +} + func (mt *monoidTree) rootPtr() *monoidTreePointer { return &monoidTreePointer{ parentStack: make([]*monoidTreeNode, 0, initialParentStackSize), @@ -315,6 +321,48 @@ func (mt *monoidTree) rootPtr() *monoidTreePointer { } } +func (mt *monoidTree) ensureCloned(mn *monoidTreeNode) *monoidTreeNode { + if mn.flags&flagCloned != 0 { + return mn + } + cloned := *mn + cloned.flags |= flagCloned + return &cloned +} + +func (mt *monoidTree) setChild(mn *monoidTreeNode, dir dir, child *monoidTreeNode) *monoidTreeNode { + if mn == nil { + panic("setChild for a nil node") + } + if mn.child(dir) == child { + return mn + } + mn = mt.ensureCloned(mn) + if dir == left { + mn.left = child + } else { + mn.right = child + } + return mn +} + +func (mt *monoidTree) flip(mn *monoidTreeNode) *monoidTreeNode { + if mn.left == nil || mn.right == nil { + panic("can't flip color with one or more nil children") + } + + left := mt.ensureCloned(mn.left) + right := mt.ensureCloned(mn.right) + mn = mt.ensureCloned(mn) + mn.left = left + mn.right = right + + mn.flags ^= flagBlack + left.flags ^= flagBlack + right.flags ^= flagBlack + return mn +} + func (mt *monoidTree) Min() 
MonoidTreePointer { if mt.root == nil { return nil @@ -388,12 +436,17 @@ func (mt *monoidTree) safeFingerprint(mn *monoidTreeNode) any { func (mt *monoidTree) updateFingerprintAndMax(mn *monoidTreeNode) { fp := mt.m.Op(mt.safeFingerprint(mn.left), mt.m.Fingerprint(mn.key)) - mn.fingerprint = mt.m.Op(fp, mt.safeFingerprint(mn.right)) + fp = mt.m.Op(fp, mt.safeFingerprint(mn.right)) + newMax := mn.key if mn.right != nil { - mn.max = mn.right.max - } else { - mn.max = mn.key + newMax = mn.right.max + } + if mn.flags&flagCloned == 0 && + (!reflect.DeepEqual(mn.fingerprint, fp) || mn.max.Compare(newMax) != 0) { + panic("BUG: updating fingerprint/max for a non-cloned node") } + mn.fingerprint = fp + mn.max = newMax } func (mt *monoidTree) rotate(mn *monoidTreeNode, d dir) *monoidTreeNode { @@ -401,13 +454,17 @@ func (mt *monoidTree) rotate(mn *monoidTreeNode, d dir) *monoidTreeNode { rd := d.flip() tmp := mn.child(rd) + if tmp == nil { + panic("BUG: nil parent after rotate") + } // fmt.Fprintf(os.Stderr, "QQQQQ: rotate %s (child at %s is %s): subtree:\n%s\n", // d, rd, tmp.key, mn.dumpSubtree()) - mn.setChild(rd, tmp.child(d)) - tmp.setChild(d, mn) + mn = mt.setChild(mn, rd, tmp.child(d)) + tmp = mt.setChild(tmp, d, mn) - tmp.color = mn.color - mn.color = red + // copy node color to the tmp + tmp.flags = (tmp.flags &^ flagBlack) | (mn.flags & flagBlack) + mn.flags &^= flagBlack // set to red // it's important to update mn first as it may be the new right child of // tmp, and we need to update tmp.max too @@ -419,13 +476,16 @@ func (mt *monoidTree) rotate(mn *monoidTreeNode, d dir) *monoidTreeNode { func (mt *monoidTree) doubleRotate(mn *monoidTreeNode, d dir) *monoidTreeNode { rd := d.flip() - mn.setChild(rd, mt.rotate(mn.child(rd), rd)) + mn = mt.setChild(mn, rd, mt.rotate(mn.child(rd), rd)) return mt.rotate(mn, d) } func (mt *monoidTree) Add(v Ordered) { mt.root = mt.insert(mt.root, v, true) - mt.root.color = black + if mt.root.flags&flagBlack == 0 { + mt.root = 
mt.ensureCloned(mt.root) + mt.root.flags |= flagBlack + } } func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTreeNode { @@ -433,7 +493,11 @@ func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTree // https://zarif98sjs.github.io/blog/blog/redblacktree/ if mn == nil { mn = mt.newNode(nil, v) - // if the tree is being modified, cached min/max ptrs are no longer valid + // the new node is not really "cloned", but at this point it's + // only present in this tree so we can safely modify it + // without allocating new nodes + mn.flags |= flagCloned + // when the tree is being modified, cached min/max ptrs are no longer valid mt.cachedMinPtr = nil mt.cachedMaxPtr = nil return mn @@ -448,7 +512,7 @@ func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTree } oldChild := mn.child(d) newChild := mt.insert(oldChild, v, rb) - mn.setChild(d, newChild) + mn = mt.setChild(mn, d, newChild) updateFP := true if rb { // non-red-black insert is used for testing @@ -460,6 +524,9 @@ func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTree return mn } +// insertFixup fixes a subtree after insert according to Red-Black tree rules. +// It returns the updated node and a boolean indicating whether the fingerprint/max +// update is needed. 
The latter is NOT the case func (mt *monoidTree) insertFixup(mn *monoidTreeNode, d dir, updateFP bool) (*monoidTreeNode, bool) { child := mn.child(d) rd := d.flip() @@ -467,28 +534,27 @@ func (mt *monoidTree) insertFixup(mn *monoidTreeNode, d dir, updateFP bool) (*mo case child.black(): return mn, true case mn.child(rd).red(): - updateFP = true // both children of mn are red => any child has 2 reds in a row // (LL LR RR RL) => flip colors if child.child(d).red() || child.child(rd).red() { - mn.flip() + return mt.flip(mn), true } + return mn, true case child.child(d).red(): // another child of mn is black // any child has 2 reds in a row (LL RR) => rotate // rotate will update fingerprint of mn and the node // that replaces it - mn = mt.rotate(mn, rd) + return mt.rotate(mn, rd), updateFP case child.child(rd).red(): // another child of mn is black // any child has 2 reds in a row (LR RL) => align first, then rotate // doubleRotate will update fingerprint of mn and the node // that replaces it - mn = mt.doubleRotate(mn, rd) + return mt.doubleRotate(mn, rd), updateFP default: - updateFP = true + return mn, true } - return mn, updateFP } func (mt *monoidTree) findGTENode(ptr *monoidTreePointer, x Ordered) bool { @@ -778,7 +844,6 @@ func (mt *monoidTree) boundedAggregate(ptr *monoidTreePointer, acc any, stop Fin } } // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", mn.key, acc) - // QQQQQ: ZXXXXXX: return acc, nil !!!! return acc } } @@ -793,10 +858,6 @@ func (mt *monoidTree) Dump() string { } // TBD: !!! values and Lookup (via findGTENode) !!! 
-// TBD: maybe: persistent rbtree -- note that MonoidTree will be immutable, -// too, in this case (Insert returns a new tree => no problem with thread safety -// or cached min/max) - // TODO: rename MonoidTreeNode to just Node, MonoidTree to SyncTree // TODO: use sync.Pool for node alloc // see also: diff --git a/hashsync/monoid_tree_test.go b/hashsync/monoid_tree_test.go index 98e7a3cb03..34db8932cf 100644 --- a/hashsync/monoid_tree_test.go +++ b/hashsync/monoid_tree_test.go @@ -293,7 +293,11 @@ func TestAscendingRanges(t *testing.T) { } func verifyBinaryTree(t *testing.T, mn *monoidTreeNode) { + cloned := mn.flags&flagCloned != 0 if mn.left != nil { + if !cloned { + require.Zero(t, mn.left.flags&flagCloned, "cloned left child of a non-cloned node") + } require.Negative(t, mn.left.key.Compare(mn.key)) // not a "real" pointer (no parent stack), just to get max leftMax := &monoidTreePointer{node: mn.left} @@ -303,6 +307,9 @@ func verifyBinaryTree(t *testing.T, mn *monoidTreeNode) { } if mn.right != nil { + if !cloned { + require.Zero(t, mn.right.flags&flagCloned, "cloned right child of a non-cloned node") + } require.Positive(t, mn.right.key.Compare(mn.key)) // not a "real" pointer (no parent stack), just to get min rightMin := &monoidTreePointer{node: mn.right} @@ -316,12 +323,12 @@ func verifyRedBlackNode(t *testing.T, mn *monoidTreeNode, blackDepth int) int { if mn == nil { return blackDepth + 1 } - if mn.color == red { + if mn.flags&flagBlack == 0 { if mn.left != nil { - require.Equal(t, black, mn.left.color, "left child of a red node is red") + require.Equal(t, flagBlack, mn.left.flags&flagBlack, "left child of a red node is red") } if mn.right != nil { - require.Equal(t, black, mn.right.color, "right child of a red node is red") + require.Equal(t, flagBlack, mn.right.flags&flagBlack, "right child of a red node is red") } } else { blackDepth++ @@ -336,7 +343,7 @@ func verifyRedBlack(t *testing.T, mt *monoidTree) { if mt.root == nil { return } - 
require.Equal(t, black, mt.root.color, "root node must be black") + require.Equal(t, flagBlack, mt.root.flags&flagBlack, "root node must be black") verifyRedBlackNode(t, mt.root, 0) } @@ -354,16 +361,31 @@ func TestRedBlackTreeInsert(t *testing.T) { // items := []byte("0123456789ABCDEFG") // shuffled := []byte("0678DF1CG5A9324BE") + trees := make([]MonoidTree, len(shuffled)) + treeDumps := make([]string, len(shuffled)) + for i := 0; i < len(shuffled); i++ { + trees[i] = tree.Copy() + treeDumps[i] = tree.Dump() + require.Equal(t, treeDumps[i], trees[i].Dump(), "initial tree dump %d", i) + tree.Add(sampleID(shuffled[i])) + if i >= 3 && i%3 == 0 { + // this shouldn't change anything + trees[i-1].Add(sampleID(shuffled[rand.Intn(i-1)])) + // cloning should not happen b/c no new nodes are inserted + require.Zero(t, trees[i-1].(*monoidTree).root.flags&flagCloned) + } + } + for i := 0; i < len(shuffled); i++ { - tree.Add(sampleID(shuffled[i])) // XXXX: Insert + require.Equal(t, treeDumps[i], trees[i].Dump(), "tree dump %d after copy", i) } + var actualItems []byte n := 0 // t.Logf("items: %q", string(items)) // t.Logf("shuffled: %q", string(shuffled)) // t.Logf("QQQQQ: tree:\n%s", tree.Dump()) - root := tree.(*monoidTree).root - verifyBinaryTree(t, root) + verifyBinaryTree(t, tree.(*monoidTree).root) verifyRedBlack(t, tree.(*monoidTree)) for ptr := tree.Min(); ptr.Key() != nil; ptr.Next() { // avoid endless loop due to bugs in the tree impl @@ -402,25 +424,36 @@ func testRandomOrderAndRanges(t *testing.T, mktree makeTestTreeFunc) { } } - expFP, expStart, expEnd := naiveRange(all, string(x), string(y), stopCount) - fp, startNode, endNode := tree.RangeFingerprint(nil, x, y, stop) + verify := func() { + expFP, expStart, expEnd := naiveRange(all, string(x), string(y), stopCount) + fp, startNode, endNode := tree.RangeFingerprint(nil, x, y, stop) - fpStr := CombinedFirst[string](fp) - curCase := fmt.Sprintf("items %q x %q y %q stopCount %d", shuffled, x, y, stopCount) - 
require.Equal(t, expFP, fpStr, "%s: fingerprint", curCase) - require.Equal(t, len(fpStr), CombinedSecond[int](fp), "%s: count", curCase) + fpStr := CombinedFirst[string](fp) + curCase := fmt.Sprintf("items %q x %q y %q stopCount %d", shuffled, x, y, stopCount) + require.Equal(t, expFP, fpStr, "%s: fingerprint", curCase) + require.Equal(t, len(fpStr), CombinedSecond[int](fp), "%s: count", curCase) - startStr := "" - if startNode != nil { - startStr = string(startNode.Key().(sampleID)) - } - require.Equal(t, expStart, startStr, "%s: next", curCase) + startStr := "" + if startNode != nil { + startStr = string(startNode.Key().(sampleID)) + } + require.Equal(t, expStart, startStr, "%s: next", curCase) - endStr := "" - if endNode != nil { - endStr = string(endNode.Key().(sampleID)) + endStr := "" + if endNode != nil { + endStr = string(endNode.Key().(sampleID)) + } + require.Equal(t, expEnd, endStr, "%s: next", curCase) } - require.Equal(t, expEnd, endStr, "%s: next", curCase) + verify() + tree1 := tree.Copy() + tree1.Add(sampleID("s")) + tree1.Add(sampleID("t")) + tree1.Add(sampleID("u")) + verify() // the original tree should be unchanged + fp, _, _ := tree1.RangeFingerprint(nil, sampleID("a"), sampleID("a"), nil) + require.Equal(t, "abcdefghijklmnopqrstu", CombinedFirst[string](fp)) + require.Equal(t, len(all)+3, CombinedSecond[int](fp)) } } From 6bac0de2566fffc3be7e4cee9be149a6f53f8e0d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 6 Jan 2024 02:36:44 +0400 Subject: [PATCH 05/76] hashsync: add tree values and lookup --- hashsync/monoid_tree.go | 59 +++++++++++++++++++++++++++-------- hashsync/monoid_tree_store.go | 5 ++- hashsync/monoid_tree_test.go | 54 +++++++++++++++++++++++++++++--- hashsync/rangesync.go | 3 +- hashsync/rangesync_test.go | 11 +++---- 5 files changed, 104 insertions(+), 28 deletions(-) diff --git a/hashsync/monoid_tree.go b/hashsync/monoid_tree.go index 1f8f920a85..4c989061df 100644 --- a/hashsync/monoid_tree.go +++ 
b/hashsync/monoid_tree.go @@ -34,7 +34,9 @@ func (fpred FingerprintPredicate) Match(y any) bool { type MonoidTree interface { Copy() MonoidTree Fingerprint() any - Add(v Ordered) + Add(k Ordered) + Set(k Ordered, v any) + Lookup(k Ordered) (any, bool) Min() MonoidTreePointer Max() MonoidTreePointer RangeFingerprint(ptr MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreePointer) @@ -63,6 +65,7 @@ func MonoidTreeFromSlice[T Ordered](m Monoid, items []T) MonoidTree { type MonoidTreePointer interface { Equal(other MonoidTreePointer) bool Key() Ordered + Value() any Prev() Next() } @@ -216,10 +219,18 @@ func (p *monoidTreePointer) Key() Ordered { return p.node.key } +func (p *monoidTreePointer) Value() any { + if p.node == nil { + return nil + } + return p.node.value +} + type monoidTreeNode struct { left *monoidTreeNode right *monoidTreeNode key Ordered + value any max Ordered fingerprint any flags flags @@ -398,11 +409,12 @@ func (mt *monoidTree) Fingerprint() any { return mt.root.fingerprint } -func (mt *monoidTree) newNode(parent *monoidTreeNode, v Ordered) *monoidTreeNode { +func (mt *monoidTree) newNode(parent *monoidTreeNode, k Ordered, v any) *monoidTreeNode { return &monoidTreeNode{ - key: v, - max: v, - fingerprint: mt.m.Fingerprint(v), + key: k, + value: v, + max: k, + fingerprint: mt.m.Fingerprint(k), } } @@ -411,10 +423,10 @@ func (mt *monoidTree) buildFromSortedSlice(parent *monoidTreeNode, s []Ordered) case 0: return nil case 1: - return mt.newNode(nil, s[0]) + return mt.newNode(nil, s[0], nil) } middle := len(s) / 2 - node := mt.newNode(parent, s[middle]) + node := mt.newNode(parent, s[middle], nil) node.left = mt.buildFromSortedSlice(node, s[:middle]) node.right = mt.buildFromSortedSlice(node, s[middle+1:]) if node.left != nil { @@ -480,19 +492,27 @@ func (mt *monoidTree) doubleRotate(mn *monoidTreeNode, d dir) *monoidTreeNode { return mt.rotate(mn, d) } -func (mt *monoidTree) Add(v Ordered) { - 
mt.root = mt.insert(mt.root, v, true) +func (mt *monoidTree) Add(k Ordered) { + mt.add(k, nil, false) +} + +func (mt *monoidTree) Set(k Ordered, v any) { + mt.add(k, v, true) +} + +func (mt *monoidTree) add(k Ordered, v any, set bool) { + mt.root = mt.insert(mt.root, k, v, true, set) if mt.root.flags&flagBlack == 0 { mt.root = mt.ensureCloned(mt.root) mt.root.flags |= flagBlack } } -func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTreeNode { +func (mt *monoidTree) insert(mn *monoidTreeNode, k Ordered, v any, rb, set bool) *monoidTreeNode { // simplified insert implementation idea from // https://zarif98sjs.github.io/blog/blog/redblacktree/ if mn == nil { - mn = mt.newNode(nil, v) + mn = mt.newNode(nil, k, v) // the new node is not really "cloned", but at this point it's // only present in this tree so we can safely modify it // without allocating new nodes @@ -502,8 +522,12 @@ func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTree mt.cachedMaxPtr = nil return mn } - c := v.Compare(mn.key) + c := k.Compare(mn.key) if c == 0 { + if v != mn.value { + mn = mt.ensureCloned(mn) + mn.value = v + } return mn } d := left @@ -511,7 +535,7 @@ func (mt *monoidTree) insert(mn *monoidTreeNode, v Ordered, rb bool) *monoidTree d = right } oldChild := mn.child(d) - newChild := mt.insert(oldChild, v, rb) + newChild := mt.insert(oldChild, k, v, rb, set) mn = mt.setChild(mn, d, newChild) updateFP := true if rb { @@ -557,6 +581,15 @@ func (mt *monoidTree) insertFixup(mn *monoidTreeNode, d dir, updateFP bool) (*mo } } +func (mt *monoidTree) Lookup(k Ordered) (any, bool) { + // TODO: lookups shouldn't cause any allocation! 
+ ptr := mt.rootPtr() + if !mt.findGTENode(ptr, k) || ptr.node == nil || ptr.Key().Compare(k) != 0 { + return nil, false + } + return ptr.Value(), true +} + func (mt *monoidTree) findGTENode(ptr *monoidTreePointer, x Ordered) bool { for { switch { diff --git a/hashsync/monoid_tree_store.go b/hashsync/monoid_tree_store.go index 3d0f7af21b..4ea94b0558 100644 --- a/hashsync/monoid_tree_store.go +++ b/hashsync/monoid_tree_store.go @@ -26,7 +26,6 @@ func (it *monoidTreeIterator) Next() { } } -// TBD: Lookup method type MonoidTreeStore struct { mt MonoidTree } @@ -40,8 +39,8 @@ func NewMonoidTreeStore(m Monoid) ItemStore { } // Add implements ItemStore. -func (mts *MonoidTreeStore) Add(v Ordered) { - mts.mt.Add(v) +func (mts *MonoidTreeStore) Add(k Ordered) { + mts.mt.Add(k) } func (mts *MonoidTreeStore) iter(ptr MonoidTreePointer) Iterator { diff --git a/hashsync/monoid_tree_test.go b/hashsync/monoid_tree_test.go index 34db8932cf..f6e02c2cb9 100644 --- a/hashsync/monoid_tree_test.go +++ b/hashsync/monoid_tree_test.go @@ -40,9 +40,9 @@ func makeStringConcatTree(chars string) MonoidTree { // dumbAdd inserts the node into the tree without trying to maintain the // red-black properties -func dumbAdd(mt MonoidTree, v Ordered) { +func dumbAdd(mt MonoidTree, k Ordered) { mtree := mt.(*monoidTree) - mtree.root = mtree.insert(mtree.root, v, false) + mtree.root = mtree.insert(mtree.root, k, nil, false, false) } // makeDumbTree constructs a binary tree by adding the chars one-by-one without @@ -59,9 +59,6 @@ func makeDumbTree(chars string) MonoidTree { } func makeRBTree(chars string) MonoidTree { - if len(chars) == 0 { - panic("empty set") - } mt := NewMonoidTree(sampleCountMonoid()) for _, c := range chars { mt.Add(sampleID(c)) @@ -465,3 +462,50 @@ func TestRandomOrderAndRanges(t *testing.T) { testRandomOrderAndRanges(t, makeRBTree) }) } + +func TestTreeValues(t *testing.T) { + tree := makeRBTree("") + tree.Add(sampleID("a")) + tree.Set(sampleID("b"), 123) + 
tree.Set(sampleID("d"), 456) + verifyOrig := func() { + v, found := tree.Lookup(sampleID("a")) + require.True(t, found) + require.Nil(t, v) + v, found = tree.Lookup(sampleID("b")) + require.True(t, found) + require.Equal(t, 123, v) + v, found = tree.Lookup(sampleID("c")) + require.False(t, found) + require.Nil(t, v) + v, found = tree.Lookup(sampleID("d")) + require.True(t, found) + require.Equal(t, 456, v) + } + verifyOrig() + + treeDump := tree.Dump() + tree1 := tree.Copy() + + // flagCloned on the root should be cleared after copy + // and not set again by Set b/c the value is the same + tree.Set(sampleID("d"), 456) // nothing changed + require.Zero(t, tree.(*monoidTree).root.flags&flagCloned) + + tree1.Set(sampleID("b"), 1234) + tree1.Set(sampleID("c"), 222) + verifyOrig() + require.Equal(t, treeDump, tree.Dump()) + v, found := tree1.Lookup(sampleID("a")) + require.True(t, found) + require.Nil(t, v) + v, found = tree1.Lookup(sampleID("b")) + require.True(t, found) + require.Equal(t, 1234, v) + v, found = tree1.Lookup(sampleID("c")) + require.True(t, found) + require.Equal(t, 222, v) + v, found = tree1.Lookup(sampleID("d")) + require.True(t, found) + require.Equal(t, 456, v) +} diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index c6c724aeac..b341bf84a7 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -55,7 +55,8 @@ type RangeInfo struct { } type ItemStore interface { - Add(v Ordered) + // Add adds a key to the store + Add(k Ordered) // GetRangeInfo returns RangeInfo for the item range in the tree. 
// If count >= 0, at most count items are returned, and RangeInfo // is returned for the corresponding subrange of the requested range diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 24064880ce..672acf7f18 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -142,9 +142,8 @@ type dumbStore struct { var _ ItemStore = &dumbStore{} -func (ds *dumbStore) Add(v Ordered) { - //slices.Insert[S ~[]E, E any](s S, i int, v ...E) - id := v.(sampleID) +func (ds *dumbStore) Add(k Ordered) { + id := k.(sampleID) if len(ds.items) == 0 { ds.items = []sampleID{id} return @@ -266,9 +265,9 @@ type verifiedStore struct { var _ ItemStore = &verifiedStore{} -func (vs *verifiedStore) Add(v Ordered) { - vs.knownGood.Add(v) - vs.store.Add(v) +func (vs *verifiedStore) Add(k Ordered) { + vs.knownGood.Add(k) + vs.store.Add(k) } func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { From 513b136e57700bf8569447a0148d1d1df28e6176 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 6 Jan 2024 23:35:16 +0400 Subject: [PATCH 06/76] hashsync: fix sending same items multiple times --- hashsync/rangesync.go | 117 ++++++++++++++++++++++++---------- hashsync/rangesync_test.go | 126 ++++++++++++++++++++++++++----------- 2 files changed, 173 insertions(+), 70 deletions(-) diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index b341bf84a7..050c43afc8 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -16,6 +16,7 @@ type SyncMessage interface { HaveItems() bool } +// Conduit handles receiving and sending peer messages type Conduit interface { // NextMessage returns the next SyncMessage, or nil if there // are no more SyncMessages. 
@@ -23,10 +24,21 @@ type Conduit interface { // NextItem returns the next item in the set or nil if there // are no more items NextItem() (Ordered, error) - // SendFingerprint sends range fingerprint to the peer + // SendFingerprint sends range fingerprint to the peer. + // Count must be > 0 SendFingerprint(x, y Ordered, fingerprint any, count int) - // SendItems sends range fingerprint to the peer along with the items - SendItems(x, y Ordered, fingerprint any, count int, start, end Iterator) + // SendEmptySet notifies the peer that it we don't have any items. + // The corresponding SyncMessage has Count() == 0, X() == nil and Y() == nil + SendEmptySet() + // SendEmptyRange notifies the peer that the specified range + // is empty on our side. The corresponding SyncMessage has Count() == 0 + SendEmptyRange(x, y Ordered) + // SendItems sends the local items to the peer, requesting back + // the items peer has in that range. The corresponding + // SyncMessage has HaveItems() == true + SendItems(x, y Ordered, count int, it Iterator) + // SendItemsOnly sends just items without any message + SendItemsOnly(count int, it Iterator) } type Option func(r *RangeSetReconciler) @@ -82,8 +94,8 @@ func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { for _, opt := range opts { opt(rsr) } - if rsr.maxSendRange == 0 { - panic("zero maxSendRange") + if rsr.maxSendRange <= 0 { + panic("bad maxSendRange") } return rsr } @@ -101,26 +113,52 @@ func (rsr *RangeSetReconciler) addItems(c Conduit) error { } } +// func qqqqRmmeK(it Iterator) any { +// if it == nil { +// return "" +// } +// if it.Key() == nil { +// return "" +// } +// return fmt.Sprintf("%s", it.Key()) +// } + func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding, start, end Iterator, x, y Ordered) Iterator { if preceding != nil && preceding.Key().Compare(x) > 0 { preceding = nil } + // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", + // qqqqRmmeK(preceding)) info := 
rsr.is.GetRangeInfo(preceding, x, y, -1) - // If the range is small enough, we send its contents - if info.Count != 0 && info.Count <= rsr.maxSendRange { - c.SendItems(x, y, info.Fingerprint, info.Count, start, end) - } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: start=%q end=%q info.Start=%q info.End=%q info.FP=%q x=%q y=%q\n", + // qqqqRmmeK(start), qqqqRmmeK(end), qqqqRmmeK(info.Start), qqqqRmmeK(info.End), info.Fingerprint, x, y) + switch { + // case info.Count != 0 && info.Count <= rsr.maxSendRange: + // // If the range is small enough, we send its contents. + // // The peer may have more items of its own in that range, + // // so we can't use SendItemsOnly(), instead we use SendItems, + // // which includes our items and asks the peer to send any + // // items it has in the range. + // c.SendItems(x, y, info.Count, info.Start) + case info.Count == 0: + // We have no more items in this subrange. + // Ask peer to send any items it has in the range + c.SendEmptyRange(x, y) + default: + // The range is non-empty and large enough. + // Send fingerprint so that the peer can further subdivide it. 
c.SendFingerprint(x, y, info.Fingerprint, info.Count) } + // fmt.Fprintf(os.Stderr, "QQQQQ: info.End=%q\n", qqqqRmmeK(info.End)) return info.End } -func (rsr *RangeSetReconciler) processFingerprint(c Conduit, preceding Iterator, msg SyncMessage) Iterator { +func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) Iterator { x := msg.X() y := msg.Y() if x == nil && y == nil { // The peer has no items at all so didn't - // even send X & Y + // even send X & Y (SendEmptySet) it := rsr.is.Min() if it == nil { // We don't have any items at all, too @@ -135,52 +173,58 @@ func (rsr *RangeSetReconciler) processFingerprint(c Conduit, preceding Iterator, info := rsr.is.GetRangeInfo(preceding, x, y, -1) // fmt.Fprintf(os.Stderr, "msg %s fp %v start %#v end %#v count %d\n", msg, info.Fingerprint, info.Start, info.End, info.Count) switch { - // FIXME: use Fingerprint interface for fingerprints - // with Equal() method - case reflect.DeepEqual(info.Fingerprint, msg.Fingerprint()): - // fmt.Fprintf(os.Stderr, "range synced: %s\n", msg) - // the range is synced - return info.End - case info.Count <= rsr.maxSendRange || msg.Count() == 0: - // The other side is missing some items, and either - // range is small enough or empty on the other side + case msg.HaveItems() || msg.Count() == 0: + // The peer has no more items to send in this range after this + // message, as it is either empty or it has sent all of its + // items in the range to us, but there may be some items on our + // side. In the latter case, send only the items themselves b/c + // the range doesn't need any further handling by the peer. 
if info.Count != 0 { - // fmt.Fprintf(os.Stderr, "small/empty incoming range: %s -> SendItems\n", msg) - c.SendItems(x, y, info.Fingerprint, info.Count, info.Start, info.End) + c.SendItemsOnly(info.Count, info.Start) + } + case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): + // The range is synced + // case (info.Count+1)/2 <= rsr.maxSendRange: + case info.Count <= rsr.maxSendRange: + // The range differs from the peer's version of it, but the it + // is small enough (or would be small enough after split) or + // empty on our side + if info.Count != 0 { + // fmt.Fprintf(os.Stderr, "small incoming range: %s -> SendItems\n", msg) + c.SendItems(x, y, info.Count, info.Start) } else { - // fmt.Fprintf(os.Stderr, "small/empty incoming range: %s -> zero count msg\n", msg) - c.SendFingerprint(x, y, info.Fingerprint, info.Count) + // fmt.Fprintf(os.Stderr, "small incoming range: %s -> empty range msg\n", msg) + c.SendEmptyRange(x, y) } - return info.End default: // Need to split the range. 
// Note that there's no special handling for rollover ranges with x >= y // These need to be handled by ItemStore.GetRangeInfo() - count := info.Count / 2 + count := (info.Count + 1) / 2 part := rsr.is.GetRangeInfo(preceding, x, y, count) if part.End == nil { panic("BUG: can't split range with count > 1") } middle := part.End.Key() next := rsr.processSubrange(c, info.Start, part.Start, part.End, x, middle) + // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) rsr.processSubrange(c, next, part.End, info.End, middle, y) // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", // msg.X(), middle, msg.Y(), msg) - return info.End } + return info.End } func (rsr *RangeSetReconciler) Initiate(c Conduit) { it := rsr.is.Min() if it == nil { - // Create a message with count 0 - c.SendFingerprint(nil, nil, nil, 0) + c.SendEmptySet() return } min := it.Key() info := rsr.is.GetRangeInfo(nil, min, min, -1) if info.Count != 0 && info.Count < rsr.maxSendRange { - c.SendItems(min, min, info.Fingerprint, info.Count, info.Start, info.End) + c.SendItems(min, min, info.Count, info.Start) } else { c.SendFingerprint(min, min, info.Fingerprint, info.Count) } @@ -204,14 +248,23 @@ func (rsr *RangeSetReconciler) Process(c Conduit) error { } for _, msg := range msgs { - // TODO: need to sort ranges, but also need to be careful - rsr.processFingerprint(c, nil, msg) + // TODO: need to sort the ranges, but also need to be careful + rsr.handleMessage(c, nil, msg) } return nil } +func fingerprintEqual(a, b any) bool { + // FIXME: use Fingerprint interface with Equal() method for fingerprints + // but still allow nil fingerprints + return reflect.DeepEqual(a, b) +} + +// TBD: successive messages with payloads can be combined! 
// TBD: limit the number of rounds (outside RangeSetReconciler) // TBD: process ascending ranges properly // TBD: bounded reconcile // TBD: limit max N of received unconfirmed items +// TBD: streaming sync with sequence numbers or timestamps +// TBD: never pass just one of X and Y as nil when decoding the messages!!! diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 672acf7f18..0b678a3d42 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -6,6 +6,7 @@ import ( "slices" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/maps" ) @@ -35,6 +36,7 @@ func (m rangeMessage) String() string { } type fakeConduit struct { + t *testing.T msgs []rangeMessage items []Ordered resp *fakeConduit @@ -46,10 +48,7 @@ func (fc *fakeConduit) done() bool { if fc.resp == nil { return true } - if len(fc.resp.msgs) == 0 { - panic("BUG: not done but no msgs") - } - return false + return len(fc.resp.msgs) == 0 && len(fc.resp.items) == 0 } func (fc *fakeConduit) NextMessage() (SyncMessage, error) { @@ -72,10 +71,14 @@ func (fc *fakeConduit) NextItem() (Ordered, error) { return nil, nil } -func (fc *fakeConduit) sendFingerprint(x, y Ordered, fingerprint any, count int, haveItems bool) { +func (fc *fakeConduit) ensureResp() { if fc.resp == nil { - fc.resp = &fakeConduit{} + fc.resp = &fakeConduit{t: fc.t} } +} + +func (fc *fakeConduit) sendMsg(x, y Ordered, fingerprint any, count int, haveItems bool) { + fc.ensureResp() msg := rangeMessage{ x: x, y: y, @@ -86,31 +89,49 @@ func (fc *fakeConduit) sendFingerprint(x, y Ordered, fingerprint any, count int, fc.resp.msgs = append(fc.resp.msgs, msg) } -func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) { - fc.sendFingerprint(x, y, fingerprint, count, false) -} - -func (fc *fakeConduit) SendItems(x, y Ordered, fingerprint any, count int, start, end Iterator) { - fc.sendFingerprint(x, y, fingerprint, count, true) - if start 
== nil || end == nil { - panic("SendItems with null iterator(s)") - } - it := start - for { +func (fc *fakeConduit) sendItems(count int, it Iterator) { + require.NotZero(fc.t, count) + require.NotNil(fc.t, it) + fc.ensureResp() + for i := 0; i < count; i++ { if it.Key() == nil { panic("fakeConduit.SendItems: went got to the end of the tree") } fc.resp.items = append(fc.resp.items, it.Key()) it.Next() - if it.Equal(end) { - break - } - } - if len(fc.resp.items) == 0 { - panic("SendItems with no items") } } +func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) { + require.NotNil(fc.t, x) + require.NotNil(fc.t, y) + require.NotZero(fc.t, count) + require.NotNil(fc.t, fingerprint) + fc.sendMsg(x, y, fingerprint, count, false) +} + +func (fc *fakeConduit) SendEmptySet() { + fc.sendMsg(nil, nil, nil, 0, false) +} + +func (fc *fakeConduit) SendEmptyRange(x, y Ordered) { + require.NotNil(fc.t, x) + require.NotNil(fc.t, y) + fc.sendMsg(x, y, nil, 0, false) +} + +func (fc *fakeConduit) SendItems(x, y Ordered, count int, it Iterator) { + require.Positive(fc.t, count) + require.NotNil(fc.t, x) + require.NotNil(fc.t, y) + fc.sendMsg(x, y, nil, count, true) + fc.sendItems(count, it) +} + +func (fc *fakeConduit) SendItemsOnly(count int, it Iterator) { + fc.sendItems(count, it) +} + type dumbStoreIterator struct { ds *dumbStore n int @@ -237,35 +258,51 @@ func (it verifiedStoreIterator) Equal(other Iterator) bool { o := other.(verifiedStoreIterator) eq1 := it.knownGood.Equal(o.knownGood) eq2 := it.it.Equal(o.it) - require.Equal(it.t, eq1, eq2, "iterators equal -- keys <%v> <%v> / <%v> <%v>", + assert.Equal(it.t, eq1, eq2, "iterators equal -- keys <%v> <%v> / <%v> <%v>", it.knownGood.Key(), it.it.Key(), o.knownGood.Key(), o.it.Key()) - require.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys of equal iterators") + assert.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys of equal iterators") return eq2 } func (it verifiedStoreIterator) Key() Ordered { 
k1 := it.knownGood.Key() k2 := it.it.Key() - require.Equal(it.t, k1, k2, "keys") + assert.Equal(it.t, k1, k2, "keys") return k2 } func (it verifiedStoreIterator) Next() { it.knownGood.Next() it.it.Next() - require.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys for Next()") + assert.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys for Next()") } type verifiedStore struct { - t *testing.T - knownGood ItemStore - store ItemStore + t *testing.T + knownGood ItemStore + store ItemStore + disableReAdd bool + added map[sampleID]struct{} } var _ ItemStore = &verifiedStore{} +func disableReAdd(s ItemStore) { + if vs, ok := s.(*verifiedStore); ok { + vs.disableReAdd = true + } +} + func (vs *verifiedStore) Add(k Ordered) { + if vs.disableReAdd { + _, found := vs.added[k.(sampleID)] + require.False(vs.t, found, "hash sent twice: %v", k) + if vs.added == nil { + vs.added = make(map[sampleID]struct{}) + } + vs.added[k.(sampleID)] = struct{}{} + } vs.knownGood.Add(k) vs.store.Add(k) } @@ -312,6 +349,8 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in it: ri2.End, } } + // QQQQQ: TODO: if count >= 0 and start+end != nil, do more calls to GetRangeInfo using resulting + // end iterator key to make sure the range is correct return ri } @@ -380,14 +419,14 @@ func storeItemStr(is ItemStore) string { if it == nil { return "" } - endAt := is.Max() + endAt := is.Min() r := "" for { r += string(it.Key().(sampleID)) + it.Next() if it.Equal(endAt) { return r } - it.Next() } } @@ -417,6 +456,7 @@ func forTestStores(t *testing.T, testFunc func(t *testing.T, factory storeFactor } } +// QQQQQ: rm func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...any) { t.Logf(fmt, args...) 
for _, m := range msgs { @@ -424,26 +464,30 @@ func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...an } } -func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds int) { - fc := &fakeConduit{} +func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { + fc := &fakeConduit{t: t} syncA.Initiate(fc) require.False(t, fc.done(), "no messages from Initiate") var i int - for i := 0; !fc.done(); i++ { + for i = 0; !fc.done(); i++ { if i == maxRounds { require.FailNow(t, "too many rounds", "didn't reconcile in %d rounds", i) } - // dumpRangeMessages(t, fc.msgs, "A %q -> B %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) fc = fc.resp + // dumpRangeMessages(t, fc.msgs, "A %q -> B %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) + nMsg += len(fc.msgs) + nItems += len(fc.items) syncB.Process(fc) if fc.done() { break } fc = fc.resp + nMsg += len(fc.msgs) + nItems += len(fc.items) // dumpRangeMessages(t, fc.msgs, "B %q --> A %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) syncA.Process(fc) } - return i + 1 + return i + 1, nMsg, nItems } func testRangeSync(t *testing.T, storeFactory storeFactory) { @@ -500,11 +544,13 @@ func testRangeSync(t *testing.T, storeFactory storeFactory) { for n, maxSendRange := range []int{1, 2, 3, 4} { t.Logf("maxSendRange: %d", maxSendRange) storeA := makeStore(t, storeFactory, tc.a) + disableReAdd(storeA) syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(maxSendRange)) storeB := makeStore(t, storeFactory, tc.b) + disableReAdd(storeB) syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(maxSendRange)) - nRounds := runSync(t, syncA, syncB, tc.maxRounds[n]) + nRounds, _, _ := runSync(t, syncA, syncB, tc.maxRounds[n]) t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) @@ -558,6 +604,7 @@ func testRandomSync(t *testing.T, storeFactory storeFactory) { 
syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(maxSendRange)) runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! + // t.Logf("maxSendRange %d a %d b %d n %d", maxSendRange, len(bytesA), len(bytesB), n) require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) require.Equal(t, string(expectedSet), storeItemStr(storeA), "expected set for %q<->%q", bytesA, bytesB) @@ -568,5 +615,8 @@ func TestRandomSync(t *testing.T) { forTestStores(t, testRandomSync) } +// TBD: test XOR + big sync // TBD: include initiate round!!! // TBD: use logger for verbose logging (messages) +// TBD: in fakeConduit -- check item count against the iterator in SendItems / SendItemsOnly!! +// TBD: record interaction using golden master in testRangeSync, together with N of rounds / msgs / items and don't check max rounds From 4cc5358a750404b1ad2483e0be5f43ff0a9a20da Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 6 Jan 2024 23:36:37 +0400 Subject: [PATCH 07/76] hashsync: test XOR fingerprint based sync for hashes --- common/types/hashes.go | 11 ++++ hashsync/monoid_tree.go | 6 +- hashsync/monoid_tree_test.go | 2 +- hashsync/xorsync.go | 61 +++++++++++++++++++ hashsync/xorsync_test.go | 113 +++++++++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+), 4 deletions(-) create mode 100644 hashsync/xorsync.go create mode 100644 hashsync/xorsync_test.go diff --git a/common/types/hashes.go b/common/types/hashes.go index dc3b786a79..aaa8525979 100644 --- a/common/types/hashes.go +++ b/common/types/hashes.go @@ -1,6 +1,7 @@ package types import ( + "bytes" "encoding/hex" "fmt" "math/big" @@ -203,6 +204,16 @@ func (h Hash32) ShortString() string { return Shorten(h.Hex()[min(2, l):], 10) } +// Compare compares a Hash32 to another hash and returns +// +// -1 if this hash is less than the other +// 0 if the hashes are equal +// 1 if this hash is greater than the other +func (h Hash32) Compare(other any) int { + oh := other.(Hash32) + return bytes.Compare(h[:], 
oh[:]) +} + // Shorten shortens a string to a specified length. func Shorten(s string, maxlen int) string { l := len(s) diff --git a/hashsync/monoid_tree.go b/hashsync/monoid_tree.go index 4c989061df..3b0133455c 100644 --- a/hashsync/monoid_tree.go +++ b/hashsync/monoid_tree.go @@ -10,20 +10,20 @@ import ( ) type Ordered interface { - Compare(other Ordered) int + Compare(other any) int } type LowerBound struct{} var _ Ordered = LowerBound{} -func (vb LowerBound) Compare(x Ordered) int { return -1 } +func (vb LowerBound) Compare(x any) int { return -1 } type UpperBound struct{} var _ Ordered = UpperBound{} -func (vb UpperBound) Compare(x Ordered) int { return 1 } +func (vb UpperBound) Compare(x any) int { return 1 } type FingerprintPredicate func(fp any) bool diff --git a/hashsync/monoid_tree_test.go b/hashsync/monoid_tree_test.go index f6e02c2cb9..071e207381 100644 --- a/hashsync/monoid_tree_test.go +++ b/hashsync/monoid_tree_test.go @@ -14,7 +14,7 @@ type sampleID string var _ Ordered = sampleID("") -func (s sampleID) Compare(other Ordered) int { +func (s sampleID) Compare(other any) int { return cmp.Compare(s, other.(sampleID)) } diff --git a/hashsync/xorsync.go b/hashsync/xorsync.go new file mode 100644 index 0000000000..88063a0054 --- /dev/null +++ b/hashsync/xorsync.go @@ -0,0 +1,61 @@ +package hashsync + +import ( + "sync" + + "github.com/zeebo/blake3" + + "github.com/spacemeshos/go-spacemesh/common/types" +) + +// Note: we don't care too much about artificially induced collisions. +// Given that none of the synced hashes are used internally or +// propagated further down the P2P network before the actual contents +// of the objects is received and validated, most an attacker can get +// is partial sync of this node with the attacker node, which doesn't +// pose any serious threat. We could even skip additional hashing +// altogether, but let's make playing the algorithm not too easy. 
+ +type Hash32To12Xor struct{} + +var _ Monoid = Hash32To12Xor{} + +func (m Hash32To12Xor) Identity() any { + return types.Hash12{} +} + +func (m Hash32To12Xor) Op(b, a any) any { + var r types.Hash12 + h1 := a.(types.Hash12) + h2 := b.(types.Hash12) + for n, b := range h1 { + r[n] = b ^ h2[n] + } + return r +} + +var hashPool = &sync.Pool{ + New: func() any { + return blake3.New() + }, +} + +func (m Hash32To12Xor) Fingerprint(v any) any { + // Blake3's New allocates too much memory, + // so we can't just call types.CalcHash12(h[:]) here + // TODO: fix types.CalcHash12() + h := v.(types.Hash32) + var r types.Hash12 + // copy(r[:], h[20:]) + // return r + hasher := hashPool.Get().(*blake3.Hasher) + defer func() { + hasher.Reset() + hashPool.Put(hasher) + }() + var hashRes [32]byte + hasher.Write(h[:]) + hasher.Sum(hashRes[:0]) + copy(r[:], hashRes[:]) + return r +} diff --git a/hashsync/xorsync_test.go b/hashsync/xorsync_test.go new file mode 100644 index 0000000000..80e6cce385 --- /dev/null +++ b/hashsync/xorsync_test.go @@ -0,0 +1,113 @@ +package hashsync + +import ( + "math/rand" + "slices" + "testing" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/stretchr/testify/require" +) + +func TestHash32To12Xor(t *testing.T) { + var m Hash32To12Xor + require.Equal(t, m.Identity(), m.Op(m.Identity(), m.Identity())) + hash1 := types.CalcHash32([]byte("foo")) + fp1 := m.Fingerprint(hash1) + hash2 := types.CalcHash32([]byte("bar")) + fp2 := m.Fingerprint(hash2) + hash3 := types.CalcHash32([]byte("baz")) + fp3 := m.Fingerprint(hash3) + require.Equal(t, fp1, m.Op(m.Identity(), fp1)) + require.Equal(t, fp2, m.Op(fp2, m.Identity())) + require.NotEqual(t, fp1, fp2) + require.NotEqual(t, fp1, fp3) + require.NotEqual(t, fp1, m.Op(fp1, fp2)) + require.NotEqual(t, fp2, m.Op(fp1, fp2)) + require.NotEqual(t, m.Identity(), m.Op(fp1, fp2)) + require.Equal(t, m.Op(m.Op(fp1, fp2), fp3), m.Op(fp1, m.Op(fp2, fp3))) +} + +func collectStoreItems[T Ordered](is ItemStore) (r 
[]T) { + it := is.Min() + if it == nil { + return nil + } + endAt := is.Min() + for { + r = append(r, it.Key().(T)) + it.Next() + if it.Equal(endAt) { + return r + } + } +} + +const numTestHashes = 100000 + +// const numTestHashes = 100 + +type catchTransferTwice struct { + ItemStore + t *testing.T + added map[types.Hash32]bool +} + +func (s *catchTransferTwice) Add(k Ordered) { + h := k.(types.Hash32) + _, found := s.added[h] + require.False(s.t, found, "hash sent twice") + s.ItemStore.Add(k) + if s.added == nil { + s.added = make(map[types.Hash32]bool) + } + s.added[h] = true +} + +const xorTestMaxSendRange = 1 + +func TestBigSyncHash32(t *testing.T) { + numSpecificA := rand.Intn(96) + 4 + numSpecificB := rand.Intn(96) + 4 + // numSpecificA := rand.Intn(6) + 4 + // numSpecificB := rand.Intn(6) + 4 + src := make([]types.Hash32, numTestHashes) + for n := range src { + src[n] = types.RandomHash() + } + + sliceA := src[:numTestHashes-numSpecificB] + storeA := NewMonoidTreeStore(Hash32To12Xor{}) + for _, h := range sliceA { + storeA.Add(h) + } + storeA = &catchTransferTwice{t: t, ItemStore: storeA} + syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(xorTestMaxSendRange)) + + sliceB := append([]types.Hash32(nil), src[:numTestHashes-numSpecificB-numSpecificA]...) + sliceB = append(sliceB, src[numTestHashes-numSpecificB:]...) 
+ storeB := NewMonoidTreeStore(Hash32To12Xor{}) + for _, h := range sliceB { + storeB.Add(h) + } + storeB = &catchTransferTwice{t: t, ItemStore: storeB} + syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(xorTestMaxSendRange)) + + nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) + excess := float64(nItems-numSpecificA-numSpecificB) / float64(numSpecificA+numSpecificB) + t.Logf("numSpecificA: %d, numSpecificB: %d, nRounds: %d, nMsg: %d, nItems: %d, excess: %.2f", + numSpecificA, numSpecificB, nRounds, nMsg, nItems, excess) + + slices.SortFunc(src, func(a, b types.Hash32) int { + return a.Compare(b) + }) + itemsA := collectStoreItems[types.Hash32](storeA) + itemsB := collectStoreItems[types.Hash32](storeB) + require.Equal(t, itemsA, itemsB) + require.Equal(t, src, itemsA) +} + +// TODO: try catching items sent twice in a simpler test +// TODO: check why insertion takes so long (1000000 items => too long wait) +// TODO: number of items transferred is unreasonable for 100k total / 1 range size: +// xorsync_test.go:56: numSpecificA: 141, numSpecificB: 784, nRounds: 11, nMsg: 13987, nItems: 3553 From 077c1ca0de7a9fc74d892f906c294b1fe41641a2 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 8 Jan 2024 14:32:07 +0400 Subject: [PATCH 08/76] p2p: streamed / interactive protocol support --- p2p/server/server.go | 265 ++++++++++++++++++++++++++++++-------- p2p/server/server_test.go | 162 ++++++++++++++++++++++- 2 files changed, 364 insertions(+), 63 deletions(-) diff --git a/p2p/server/server.go b/p2p/server/server.go index 7b30f691be..419771c417 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -77,9 +77,42 @@ func WithRequestsPerInterval(n int, interval time.Duration) Opt { } } +type Interactor interface { + Send(data []byte) error + SendError(err error) error + Receive() ([]byte, error) +} + +type ServerHandler interface { + Handle(ctx context.Context, i Interactor) (time.Duration, error) +} + +type InteractiveHandler func(ctx 
context.Context, i Interactor) (time.Duration, error) + +func (ih InteractiveHandler) Handle(ctx context.Context, i Interactor) (time.Duration, error) { + return ih(ctx, i) +} + // Handler is the handler to be defined by the application. type Handler func(context.Context, []byte) ([]byte, error) +func (h Handler) Handle(ctx context.Context, i Interactor) (time.Duration, error) { + in, err := i.Receive() + if err != nil { + return 0, err + } + start := time.Now() + out, err := h(ctx, in) + duration := time.Since(start) + if err != nil { + if err := i.SendError(err); err != nil { + return duration, err + } + return duration, err + } + return duration, i.Send(out) +} + //go:generate scalegen -types Response // Response is a server response. @@ -101,7 +134,7 @@ type Host interface { type Server struct { logger log.Log protocol string - handler Handler + handler ServerHandler timeout time.Duration requestLimit int queueSize int @@ -114,7 +147,7 @@ type Server struct { } // New server for the handler. 
-func New(h Host, proto string, handler Handler, opts ...Opt) *Server { +func New(h Host, proto string, handler ServerHandler, opts ...Opt) *Server { srv := &Server{ logger: log.NewNop(), protocol: proto, @@ -184,52 +217,24 @@ func (s *Server) Run(ctx context.Context) error { } func (s *Server) queueHandler(ctx context.Context, stream network.Stream) { - defer stream.Close() - _ = stream.SetDeadline(time.Now().Add(s.timeout)) - defer stream.SetDeadline(time.Time{}) - rd := bufio.NewReader(stream) - size, err := varint.ReadUvarint(rd) - if err != nil { - return - } - if size > uint64(s.requestLimit) { - s.logger.With().Warning("request limit overflow", - log.Int("limit", s.requestLimit), - log.Uint64("request", size), - ) - stream.Conn().Close() - return - } - buf := make([]byte, size) - _, err = io.ReadFull(rd, buf) - if err != nil { + ps := s.peerStream(stream) + defer ps.Close() + defer ps.clearDeadline() + if err := ps.readInitialRequest(); err != nil { + s.logger.With().Debug("error receiving request", log.Err(err)) + ps.Conn().Close() return } - start := time.Now() - buf, err = s.handler(log.WithNewRequestID(ctx), buf) - s.logger.With().Debug("protocol handler execution time", - log.String("protocol", s.protocol), - log.Duration("duration", time.Since(start)), - ) - var resp Response + d, err := s.handler.Handle(ctx, ps) if err != nil { - resp.Error = err.Error() + s.logger.With().Debug("protocol handler execution time", + log.String("protocol", s.protocol), + log.Duration("duration", d)) } else { - resp.Data = buf - } - - wr := bufio.NewWriter(stream) - if _, err := codec.EncodeTo(wr, &resp); err != nil { - s.logger.With().Warning( - "failed to write response", - log.Int("resp.Data len", len(resp.Data)), - log.Int("resp.Error len", len(resp.Error)), - log.Err(err), - ) - return - } - if err := wr.Flush(); err != nil { - s.logger.With().Warning("failed to flush stream", log.Err(err)) + s.logger.With().Debug("protocol handler execution failed", + 
log.String("protocol", s.protocol), + log.Duration("duration", d), + log.Err(err)) } } @@ -275,7 +280,54 @@ func (s *Server) Request( return nil } +func (s *Server) InteractiveRequest( + ctx context.Context, + pid peer.ID, + initialRequest []byte, + handler InteractiveHandler, + failure func(error), +) error { + // start := time.Now() + if len(initialRequest) > s.requestLimit { + return fmt.Errorf("request length (%d) is longer than limit %d", len(initialRequest), s.requestLimit) + } + if s.h.Network().Connectedness(pid) != network.Connected { + return fmt.Errorf("%w: %s", ErrNotConnected, pid) + } + go func() { + ps, err := s.beginRequest(ctx, pid) + if err != nil { + failure(err) + return + } + defer ps.Close() + defer ps.clearDeadline() + + ps.updateDeadline() + ps.sendInitialRequest(initialRequest) + + if _, err = handler.Handle(ctx, ps); err != nil { + failure(err) + } + // TODO: client latency metrics + }() + return nil +} + func (s *Server) request(ctx context.Context, pid peer.ID, req []byte) (*Response, error) { + ps, err := s.beginRequest(ctx, pid) + if err != nil { + return nil, err + } + defer ps.Close() + defer ps.clearDeadline() + + ps.updateDeadline() + ps.sendInitialRequest(req) + return ps.readResponse() +} + +func (s *Server) beginRequest(ctx context.Context, pid peer.ID) (*peerStream, error) { ctx, cancel := context.WithTimeout(ctx, s.timeout) defer cancel() @@ -288,27 +340,126 @@ func (s *Server) request(ctx context.Context, pid peer.ID, req []byte) (*Respons if err != nil { return nil, err } - defer stream.Close() - defer stream.SetDeadline(time.Time{}) - _ = stream.SetDeadline(time.Now().Add(s.timeout)) - wr := bufio.NewWriter(stream) + return s.peerStream(stream), err +} + +func (s *Server) peerStream(stream network.Stream) *peerStream { + return &peerStream{ + Stream: stream, + rd: bufio.NewReader(stream), + wr: bufio.NewWriter(stream), + s: s, + } +} + +type peerStream struct { + network.Stream + rd *bufio.Reader + wr *bufio.Writer + s 
*Server + initReq []byte +} + +func (ps *peerStream) updateDeadline() { + ps.SetDeadline(time.Now().Add(ps.s.timeout)) +} + +func (ps *peerStream) clearDeadline() { + ps.SetDeadline(time.Time{}) +} + +func (ps *peerStream) sendInitialRequest(data []byte) error { sz := make([]byte, binary.MaxVarintLen64) - n := binary.PutUvarint(sz, uint64(len(req))) - if _, err := wr.Write(sz[:n]); err != nil { - return nil, err + n := binary.PutUvarint(sz, uint64(len(data))) + if _, err := ps.wr.Write(sz[:n]); err != nil { + return err } - if _, err := wr.Write(req); err != nil { - return nil, err + if _, err := ps.wr.Write(data); err != nil { + return err } - if err := wr.Flush(); err != nil { - return nil, err + if err := ps.wr.Flush(); err != nil { + return err + } + return nil +} + +var errOversizedRequest = errors.New("request size limit exceeded") + +func (ps *peerStream) readInitialRequest() error { + size, err := varint.ReadUvarint(ps.rd) + if err != nil { + return err + } + if size > uint64(ps.s.requestLimit) { + ps.s.logger.With().Warning("request limit overflow", + log.Int("limit", ps.s.requestLimit), + log.Uint64("request", size), + ) + ps.Conn().Close() + return errOversizedRequest + } + ps.initReq = make([]byte, size) + _, err = io.ReadFull(ps.rd, ps.initReq) + if err != nil { + return err + } + return nil +} + +func (ps *peerStream) sendResponse(resp *Response) error { + if _, err := codec.EncodeTo(ps.wr, resp); err != nil { + ps.s.logger.With().Warning( + "failed to write response", + log.Int("resp.Data len", len(resp.Data)), + log.Int("resp.Error len", len(resp.Error)), + log.Err(err), + ) + return err + } + + if err := ps.wr.Flush(); err != nil { + ps.s.logger.With().Warning("failed to flush stream", log.Err(err)) + return err } - rd := bufio.NewReader(stream) + return nil +} + +func (ps *peerStream) readResponse() (*Response, error) { var r Response - if _, err = codec.DecodeFrom(rd, &r); err != nil { + if _, err := codec.DecodeFrom(ps.rd, &r); err != nil { 
return nil, err } return &r, nil } + +func (ps *peerStream) Receive() (data []byte, err error) { + if ps.initReq != nil { + data, ps.initReq = ps.initReq, nil + return data, nil + } + resp, err := ps.readResponse() + switch { + case err != nil: + return nil, err + case resp.Error != "": + return nil, fmt.Errorf("%w: %s", RemoteError, resp.Error) + default: + ps.updateDeadline() + return resp.Data, nil + } +} + +func (ps *peerStream) Send(data []byte) error { + ps.updateDeadline() + return ps.sendResponse(&Response{Data: data}) +} + +func (ps *peerStream) SendError(err error) error { + return ps.sendResponse(&Response{Error: err.Error()}) +} + +var RemoteError = errors.New("peer reported an error") + +// TODO: InteractiveRequest should be same as Request diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index 22183342be..6fd27231a3 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -1,8 +1,10 @@ package server import ( + "bytes" "context" "errors" + "fmt" "sync/atomic" "testing" "time" @@ -26,12 +28,12 @@ func TestServer(t *testing.T) { errch := make(chan error, 1) respch := make(chan []byte, 1) - handler := func(_ context.Context, msg []byte) ([]byte, error) { + handler := Handler(func(_ context.Context, msg []byte) ([]byte, error) { return msg, nil - } - errhandler := func(_ context.Context, _ []byte) ([]byte, error) { + }) + errhandler := Handler(func(_ context.Context, _ []byte) ([]byte, error) { return nil, testErr - } + }) opts := []Opt{ WithTimeout(100 * time.Millisecond), WithLog(logtest.New(t)), @@ -135,6 +137,154 @@ func TestServer(t *testing.T) { }) } +func TestInteractive(t *testing.T) { + const limit = 1024 + + mesh, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + proto := "itest" + errch := make(chan error, 1) + respch := make(chan []string, 1) + + handler := InteractiveHandler(func(ctx context.Context, i Interactor) (time.Duration, error) { + hi, err := i.Receive() + if err != nil { + return 0, 
err + } + if string(hi) != "" { + return 0, fmt.Errorf("server: bad greeting message: %q", hi) + } + if err := i.Send([]byte("")); err != nil { + return 0, err + } + + var b bytes.Buffer + for { + msg, err := i.Receive() + if err != nil { + return 0, err + } + m := string(msg) + if m == "" { + break + } else if m == "" { + retErr := errors.New("duplicate ") + if err := i.SendError(retErr); err != nil { + return 0, err + } + return 0, retErr + } else { + b.WriteString("+" + m) + } + } + + for _, s := range []string{b.String(), "foo", "bar", "baz", ""} { + if err := i.Send([]byte(s)); err != nil { + return 0, err + } + } + + return 0, nil + }) + opts := []Opt{ + WithTimeout(100 * time.Millisecond), + WithLog(logtest.New(t)), + } + client := New(mesh.Hosts()[0], proto, handler, append(opts, WithRequestSizeLimit(2*limit))...) + srv1 := New(mesh.Hosts()[1], proto, handler, append(opts, WithRequestSizeLimit(limit))...) + ctx, cancel := context.WithCancel(context.Background()) + var eg errgroup.Group + eg.Go(func() error { + return srv1.Run(ctx) + }) + require.Eventually(t, func() bool { + for _, h := range mesh.Hosts()[1:] { + if len(h.Mux().Protocols()) == 0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + t.Cleanup(func() { + cancel() + eg.Wait() + }) + makeRespHandler := func(strs ...string) InteractiveHandler { + return func(ctx context.Context, i Interactor) (time.Duration, error) { + hi, err := i.Receive() + if err != nil { + return 0, err + } + if string(hi) != "" { + return 0, fmt.Errorf("client: bad greeting message: %q", hi) + } + + for _, s := range append(strs, "") { + if err := i.Send([]byte(s)); err != nil { + return 0, err + } + } + + var strs []string + for { + msg, err := i.Receive() + if err != nil { + return 0, err + } + m := string(msg) + switch m { + case "": + respch <- strs + return 0, nil + case "": + return 0, errors.New("duplicate ") + default: + strs = append(strs, m) + } + } + } + } + respErrHandler := func(err 
error) { + select { + case <-ctx.Done(): + case errch <- err: + } + } + initReq := []byte("") + + t.Run("ReceiveMessage", func(t *testing.T) { + respHandler := makeRespHandler("abc", "def", "ghi") + require.NoError( + t, + client.InteractiveRequest(ctx, mesh.Hosts()[1].ID(), initReq, respHandler, respErrHandler), + ) + select { + case <-time.After(time.Second): + require.FailNow(t, "timed out while waiting for interaction to finish") + case strs := <-respch: + require.Equal(t, []string{"+abc+def+ghi", "foo", "bar", "baz"}, strs) + case err := <-errch: + require.Fail(t, "unexpected error", "%v", err) + } + }) + + t.Run("ReceiveError", func(t *testing.T) { + respHandler := makeRespHandler("abc", "def", "") + require.NoError( + t, + client.InteractiveRequest(ctx, mesh.Hosts()[1].ID(), initReq, respHandler, respErrHandler), + ) + select { + case <-time.After(time.Second): + require.FailNow(t, "timed out while waiting for error response") + case <-respch: + require.FailNow(t, "got unexpected response") + case err := <-errch: + require.ErrorContains(t, err, "peer reported an error: duplicate ") + } + }) +} + func TestQueued(t *testing.T) { mesh, err := mocknet.FullMeshConnected(2) require.NoError(t, err) @@ -150,9 +300,9 @@ func TestQueued(t *testing.T) { srv := New( mesh.Hosts()[1], proto, - func(_ context.Context, msg []byte) ([]byte, error) { + Handler(func(_ context.Context, msg []byte) ([]byte, error) { return msg, nil - }, + }), WithQueueSize(total/4), WithRequestsPerInterval(50, time.Second), WithMetrics(), From 5be5a1efaf2fbc4aff4d098c0289baaed08ea54a Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 10 Jan 2024 06:44:02 +0400 Subject: [PATCH 09/76] hashsync: implement wire protocol --- common/types/hashes.go | 6 + hashsync/handler.go | 269 ++++++++++++++++++++ hashsync/handler_test.go | 477 +++++++++++++++++++++++++++++++++++ hashsync/interface.go | 13 + hashsync/monoid_tree_test.go | 1 + hashsync/rangesync.go | 305 ++++++++++++++++------ 
hashsync/rangesync_test.go | 174 +++++++------ hashsync/wire_types.go | 109 ++++++++ hashsync/wire_types_scale.go | 258 +++++++++++++++++++ hashsync/xorsync.go | 2 - hashsync/xorsync_test.go | 64 +++-- 11 files changed, 1498 insertions(+), 180 deletions(-) create mode 100644 hashsync/handler.go create mode 100644 hashsync/handler_test.go create mode 100644 hashsync/interface.go create mode 100644 hashsync/wire_types.go create mode 100644 hashsync/wire_types_scale.go diff --git a/common/types/hashes.go b/common/types/hashes.go index aaa8525979..a8981fc051 100644 --- a/common/types/hashes.go +++ b/common/types/hashes.go @@ -34,6 +34,12 @@ type Hash20 [hash20Length]byte // Field returns a log field. Implements the LoggableField interface. func (h Hash12) Field() log.Field { return log.String("hash", hex.EncodeToString(h[:])) } +// String implements the stringer interface and is used also by the logger when +// doing full logging into a file. +func (h Hash12) String() string { + return util.Encode(h[:5]) +} + // Bytes gets the byte representation of the underlying hash. 
func (h Hash20) Bytes() []byte { return h[:] } diff --git a/hashsync/handler.go b/hashsync/handler.go new file mode 100644 index 0000000000..9624f57a47 --- /dev/null +++ b/hashsync/handler.go @@ -0,0 +1,269 @@ +package hashsync + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" +) + +type outboundMessage struct { + code MessageType // TODO: "mt" + msg codec.Encodable +} + +type conduitState int + +type wireConduit struct { + i server.Interactor + pendingMsgs []SyncMessage + initReqBuf *bytes.Buffer + // rmmePrint bool +} + +var _ Conduit = &wireConduit{} + +func (c *wireConduit) reset() { + c.pendingMsgs = nil +} + +// receive receives a single frame from the Interactor and decodes one +// or more SyncMessages from it. The frames contain just one message +// except for the initial frame which may contain multiple messages +// b/c of the way Server handles the initial request +func (c *wireConduit) receive() (msgs []SyncMessage, err error) { + data, err := c.i.Receive() + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, errors.New("zero length sync message") + } + b := bytes.NewBuffer(data) + for { + code, err := b.ReadByte() + if err != nil { + if !errors.Is(err, io.EOF) { + // this shouldn't really happen + return nil, err + } + // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: decoded msgs: %#v\n", msgs) + return msgs, nil + } + mtype := MessageType(code) + // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: receive message type %s\n", mtype) + switch mtype { + case MessageTypeDone: + msgs = append(msgs, &DoneMessage{}) + case MessageTypeEndRound: + msgs = append(msgs, &EndRoundMessage{}) + case MessageTypeItemBatch: + var m ItemBatchMessage + if _, err := codec.DecodeFrom(b, &m); err != nil { + return nil, err + } + msgs = append(msgs, 
&m) + case MessageTypeEmptySet: + msgs = append(msgs, &EmptySetMessage{}) + case MessageTypeEmptyRange: + var m EmptyRangeMessage + if _, err := codec.DecodeFrom(b, &m); err != nil { + return nil, err + } + msgs = append(msgs, &m) + case MessageTypeFingerprint: + var m FingerprintMessage + if _, err := codec.DecodeFrom(b, &m); err != nil { + return nil, err + } + msgs = append(msgs, &m) + case MessageTypeRangeContents: + var m RangeContentsMessage + if _, err := codec.DecodeFrom(b, &m); err != nil { + return nil, err + } + msgs = append(msgs, &m) + default: + return nil, fmt.Errorf("invalid message code %02x", code) + } + } +} + +func (c *wireConduit) send(m SyncMessage) error { + // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: sending %s m %#v\n", m.Type(), m) + msg := []byte{byte(m.Type())} + // if c.rmmePrint { + // fmt.Fprintf(os.Stderr, "QQQQQ: send: %s\n", SyncMessageToString(m)) + // } + encoded, err := codec.Encode(m.(codec.Encodable)) + if err != nil { + return fmt.Errorf("error encoding %T: %w", m, err) + } + msg = append(msg, encoded...) + if c.initReqBuf != nil { + c.initReqBuf.Write(msg) + } else { + if err := c.i.Send(msg); err != nil { + return err + } + } + return nil +} + +// NextMessage implements Conduit. 
+func (c *wireConduit) NextMessage() (SyncMessage, error) { + if len(c.pendingMsgs) != 0 { + m := c.pendingMsgs[0] + c.pendingMsgs = c.pendingMsgs[1:] + // if c.rmmePrint { + // fmt.Fprintf(os.Stderr, "QQQQQ: recv: %s\n", SyncMessageToString(m)) + // } + return m, nil + } + + msgs, err := c.receive() + if err != nil { + return nil, err + } + if len(msgs) == 0 { + return nil, nil + } + + c.pendingMsgs = msgs[1:] + // if c.rmmePrint { + // fmt.Fprintf(os.Stderr, "QQQQQ: recv: %s\n", SyncMessageToString(msgs[0])) + // } + return msgs[0], nil +} + +func (c *wireConduit) SendFingerprint(x Ordered, y Ordered, fingerprint any, count int) error { + return c.send(&FingerprintMessage{ + RangeX: x.(types.Hash32), + RangeY: y.(types.Hash32), + RangeFingerprint: fingerprint.(types.Hash12), + NumItems: uint32(count), + }) +} + +func (c *wireConduit) SendEmptySet() error { + return c.send(&EmptySetMessage{}) +} + +func (c *wireConduit) SendEmptyRange(x Ordered, y Ordered) error { + return c.send(&EmptyRangeMessage{RangeX: x.(types.Hash32), RangeY: y.(types.Hash32)}) +} + +func (c *wireConduit) SendRangeContents(x Ordered, y Ordered, count int) error { + return c.send(&RangeContentsMessage{ + RangeX: x.(types.Hash32), + RangeY: y.(types.Hash32), + NumItems: uint32(count), + }) +} + +func (c *wireConduit) SendItems(count, itemChunkSize int, it Iterator) error { + for i := 0; i < count; i += itemChunkSize { + var msg ItemBatchMessage + n := min(itemChunkSize, count-i) + for n > 0 { + if it.Key() == nil { + panic("fakeConduit.SendItems: went got to the end of the tree") + } + msg.Contents = append(msg.Contents, it.Key().(types.Hash32)) + it.Next() + n-- + } + if err := c.send(&msg); err != nil { + return err + } + } + return nil +} + +func (c *wireConduit) SendEndRound() error { + return c.send(&EndRoundMessage{}) +} + +func (c *wireConduit) SendDone() error { + return c.send(&DoneMessage{}) +} + +func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, error) { 
+ c.initReqBuf = new(bytes.Buffer) + defer func() { c.initReqBuf = nil }() + if err := toCall(c); err != nil { + return nil, err + } + return c.initReqBuf.Bytes(), nil +} + +func makeHandler(rsr *RangeSetReconciler, c *wireConduit, done chan struct{}) server.InteractiveHandler { + return func(ctx context.Context, i server.Interactor) (time.Duration, error) { + defer func() { + if done != nil { + close(done) + } + }() + c.i = i + for { + c.reset() + // Process() will receive all items and messages from the peer + syncDone, err := rsr.Process(c) + if err != nil { + // do not close done if we're returning an + // error, as the channel will be closed in the + // error handler func + done = nil + return 0, err + } else if syncDone { + return 0, nil + } + } + } +} + +func MakeServerHandler(rsr *RangeSetReconciler) server.InteractiveHandler { + return func(ctx context.Context, i server.Interactor) (time.Duration, error) { + var c wireConduit + h := makeHandler(rsr, &c, nil) + return h(ctx, i) + } +} + +func SyncStore(ctx context.Context, r requester, peer p2p.Peer, rsr *RangeSetReconciler) error { + var c wireConduit + // c.rmmePrint = true + initReq, err := c.withInitialRequest(rsr.Initiate) + if err != nil { + return err + } + done := make(chan struct{}, 1) + h := makeHandler(rsr, &c, done) + var reqErr error + if err = r.InteractiveRequest(ctx, peer, initReq, h, func(err error) { + reqErr = err + close(done) + }); err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return reqErr + } +} + +// TODO: HashSyncer object (SyncStore, also server handler, implementing ServerHandler) +// TODO: HashSyncer options instead of itemChunkSize (WithItemChunkSize, WithMaxSendRange) +// TODO: duration +// TODO: validate counts +// TODO: don't forget about Initiate!!! 
+// TBD: use MessageType instead of byte diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go new file mode 100644 index 0000000000..7585d90f16 --- /dev/null +++ b/hashsync/handler_test.go @@ -0,0 +1,477 @@ +package hashsync + +import ( + "context" + "fmt" + "slices" + "sync/atomic" + "testing" + "time" + + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/log/logtest" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" +) + +type fakeMessage struct { + data []byte + error string +} + +type fakeInteractor struct { + fr *fakeRequester + ctx context.Context + sendCh chan fakeMessage + recvCh chan fakeMessage +} + +func (i *fakeInteractor) Send(data []byte) error { + // fmt.Fprintf(os.Stderr, "%p: send %q\n", i, data) + select { + case i.sendCh <- fakeMessage{data: data}: + atomic.AddUint32(&i.fr.bytesSent, uint32(len(data))) + return nil + case <-i.ctx.Done(): + return i.ctx.Err() + } +} + +func (i *fakeInteractor) SendError(err error) error { + // fmt.Fprintf(os.Stderr, "%p: send error %q\n", i, err) + select { + case i.sendCh <- fakeMessage{error: err.Error()}: + atomic.AddUint32(&i.fr.bytesSent, uint32(len(err.Error()))) + return nil + case <-i.ctx.Done(): + return i.ctx.Err() + } +} + +func (i *fakeInteractor) Receive() ([]byte, error) { + // fmt.Fprintf(os.Stderr, "%p: receive\n", i) + var m fakeMessage + select { + case m = <-i.recvCh: + case <-i.ctx.Done(): + return nil, i.ctx.Err() + } + // fmt.Fprintf(os.Stderr, "%p: received %#v\n", i, m) + if m.error != "" { + atomic.AddUint32(&i.fr.bytesReceived, uint32(len(m.error))) + return nil, fmt.Errorf("%w: %s", server.RemoteError, m.error) + } + atomic.AddUint32(&i.fr.bytesReceived, uint32(len(m.data))) + return m.data, nil +} + +type incomingRequest struct { + sendCh chan fakeMessage + recvCh 
chan fakeMessage +} + +var _ server.Interactor = &fakeInteractor{} + +type fakeRequester struct { + id p2p.Peer + handler server.ServerHandler + peers map[p2p.Peer]*fakeRequester + reqCh chan incomingRequest + bytesSent uint32 + bytesReceived uint32 +} + +var _ requester = &fakeRequester{} + +func newFakeRequester(id p2p.Peer, handler server.ServerHandler, peers ...requester) *fakeRequester { + fr := &fakeRequester{ + id: id, + handler: handler, + reqCh: make(chan incomingRequest), + peers: make(map[p2p.Peer]*fakeRequester), + } + for _, p := range peers { + pfr := p.(*fakeRequester) + fr.peers[pfr.id] = pfr + } + return fr +} + +func (fr *fakeRequester) Run(ctx context.Context) error { + if fr.handler == nil { + panic("no handler") + } + for { + var req incomingRequest + select { + case <-ctx.Done(): + return nil + case req = <-fr.reqCh: + } + i := &fakeInteractor{ + fr: fr, + ctx: ctx, + sendCh: req.sendCh, + recvCh: req.recvCh, + } + fr.handler.Handle(ctx, i) + } +} + +func (fr *fakeRequester) request( + ctx context.Context, + pid p2p.Peer, + initialRequest []byte, + handler server.InteractiveHandler, +) error { + p, found := fr.peers[pid] + if !found { + return fmt.Errorf("bad peer %q", pid) + } + i := &fakeInteractor{ + fr: fr, + ctx: ctx, + sendCh: make(chan fakeMessage, 1), + recvCh: make(chan fakeMessage), + } + i.sendCh <- fakeMessage{data: initialRequest} + select { + case p.reqCh <- incomingRequest{ + sendCh: i.recvCh, + recvCh: i.sendCh, + }: + case <-ctx.Done(): + return ctx.Err() + } + _, err := handler(ctx, i) + return err +} + +func (fr *fakeRequester) InteractiveRequest( + ctx context.Context, + pid p2p.Peer, + initialRequest []byte, + handler server.InteractiveHandler, + failure func(error), +) error { + go func() { + err := fr.request(ctx, pid, initialRequest, handler) + if err != nil { + failure(err) + } + }() + return nil +} + +type sliceIterator struct { + s []Ordered +} + +var _ Iterator = &sliceIterator{} + +func (it *sliceIterator) 
Equal(other Iterator) bool { + // not used by wireConduit + return false +} + +func (it *sliceIterator) Key() Ordered { + if len(it.s) != 0 { + return it.s[0] + } + return nil +} + +func (it *sliceIterator) Next() { + if len(it.s) != 0 { + it.s = it.s[1:] + } +} + +type fakeSend struct { + x, y Ordered + count int + fp any + items []Ordered + endRound bool + done bool +} + +func (fs *fakeSend) send(c Conduit) error { + switch { + case fs.endRound: + return c.SendEndRound() + case fs.done: + return c.SendDone() + case len(fs.items) != 0: + items := slices.Clone(fs.items) + return c.SendItems(len(items), 2, &sliceIterator{s: items}) + case fs.x == nil || fs.y == nil: + return c.SendEmptySet() + case fs.count == 0: + return c.SendEmptyRange(fs.x, fs.y) + case fs.fp != nil: + return c.SendFingerprint(fs.x, fs.y, fs.fp, fs.count) + default: + return c.SendRangeContents(fs.x, fs.y, fs.count) + } +} + +type fakeRound struct { + name string + expectMsgs []SyncMessage + toSend []*fakeSend +} + +func (r *fakeRound) handleMessages(t *testing.T, c Conduit) error { + // fmt.Fprintf(os.Stderr, "fakeRound %q: handleMessages\n", r.name) + var msgs []SyncMessage + for { + msg, err := c.NextMessage() + if err != nil { + // fmt.Fprintf(os.Stderr, "fakeRound %q: error getting message: %v\n", r.name, err) + return fmt.Errorf("NextMessage(): %w", err) + } else if msg == nil { + // fmt.Fprintf(os.Stderr, "fakeRound %q: consumed all messages\n", r.name) + break + } + // fmt.Fprintf(os.Stderr, "fakeRound %q: got message %#v\n", r.name, msg) + msgs = append(msgs, msg) + if msg.Type() == MessageTypeDone || msg.Type() == MessageTypeEndRound { + break + } + } + require.Equal(t, r.expectMsgs, msgs, "messages for round %q", r.name) + return nil +} + +func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { + if err := r.handleMessages(t, c); err != nil { + return err + } + for _, s := range r.toSend { + if err := s.send(c); err != nil { + return err + } + } + return nil +} + 
+func makeTestHandler(t *testing.T, c *wireConduit, done chan struct{}, rounds []fakeRound) server.InteractiveHandler { + return func(ctx context.Context, i server.Interactor) (time.Duration, error) { + defer func() { + if done != nil { + close(done) + } + }() + if c == nil { + c = &wireConduit{i: i} + } else { + c.i = i + } + for _, round := range rounds { + if err := round.handleConversation(t, c); err != nil { + done = nil + return 0, err + } + } + return 0, nil + } +} + +func TestWireConduit(t *testing.T) { + hs := make([]types.Hash32, 16) + for n := range hs { + hs[n] = types.RandomHash() + } + fp := types.Hash12(hs[2][:12]) + srvHandler := makeTestHandler(t, nil, nil, []fakeRound{ + { + name: "server got 1st request", + expectMsgs: []SyncMessage{ + &FingerprintMessage{ + RangeX: hs[0], + RangeY: hs[1], + RangeFingerprint: fp, + NumItems: 4, + }, + &EndRoundMessage{}, + }, + toSend: []*fakeSend{ + { + x: hs[0], + y: hs[3], + count: 2, + }, + { + x: hs[3], + y: hs[6], + count: 2, + }, + { + items: []Ordered{hs[4], hs[5], hs[7], hs[8]}, + }, + { + endRound: true, + }, + }, + }, + { + name: "server got 2nd request", + expectMsgs: []SyncMessage{ + &ItemBatchMessage{ + Contents: []types.Hash32{hs[9], hs[10]}, + }, + &ItemBatchMessage{ + Contents: []types.Hash32{hs[11]}, + }, + &EndRoundMessage{}, + }, + toSend: []*fakeSend{ + { + done: true, + }, + }, + }, + }) + + srv := newFakeRequester("srv", srvHandler) + var eg errgroup.Group + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + eg.Wait() + }() + eg.Go(func() error { + return srv.Run(ctx) + }) + + client := newFakeRequester("client", nil, srv) + var c wireConduit + initReq, err := c.withInitialRequest(func(c Conduit) error { + if err := c.SendFingerprint(hs[0], hs[1], fp, 4); err != nil { + return err + } + return c.SendEndRound() + }) + require.NoError(t, err) + done := make(chan struct{}) + clientHandler := makeTestHandler(t, &c, done, []fakeRound{ + { + name: "client got 
1st response", + expectMsgs: []SyncMessage{ + &RangeContentsMessage{ + RangeX: hs[0], + RangeY: hs[3], + NumItems: 2, + }, + &RangeContentsMessage{ + RangeX: hs[3], + RangeY: hs[6], + NumItems: 2, + }, + &ItemBatchMessage{ + Contents: []types.Hash32{hs[4], hs[5]}, + }, + &ItemBatchMessage{ + Contents: []types.Hash32{hs[7], hs[8]}, + }, + &EndRoundMessage{}, + }, + toSend: []*fakeSend{ + { + items: []Ordered{hs[9], hs[10], hs[11]}, + }, + { + endRound: true, + }, + }, + }, + { + name: "client got 2nd response", + expectMsgs: []SyncMessage{ + &DoneMessage{}, + }, + }, + }) + err = client.InteractiveRequest(context.Background(), "srv", initReq, clientHandler, func(err error) { + require.FailNow(t, "fail handler called", "error: %v", err) + }) + require.NoError(t, err) + <-done +} + +type getRequesterFunc func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) + +func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { + cfg := xorSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 100000, + minNumSpecificA: 4, + maxNumSpecificA: 100, + minNumSpecificB: 4, + maxNumSpecificB: 100, + } + var client requester + verifyXORSync(t, cfg, func(syncA, syncB *RangeSetReconciler, numSpecific int) { + srvHandler := MakeServerHandler(syncA) + srv, srvPeerID := getRequester("srv", srvHandler) + var eg errgroup.Group + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + eg.Wait() + }() + eg.Go(func() error { + return srv.Run(ctx) + }) + + client, _ = getRequester("client", nil, srv) + err := SyncStore(ctx, client, srvPeerID, syncB) + require.NoError(t, err) + + if fr, ok := client.(*fakeRequester); ok { + t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", + numSpecific, fr.bytesSent, fr.bytesReceived) + } + }) + return client +} + +func TestWireSync(t *testing.T) { + t.Run("fake requester", func(t *testing.T) { + testWireSync(t, func(name string, handler server.InteractiveHandler, peers 
...requester) (requester, p2p.Peer) { + pid := p2p.Peer(name) + return newFakeRequester(pid, handler, peers...), pid + }) + }) + + t.Run("p2p", func(t *testing.T) { + mesh, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + proto := "itest" + opts := []server.Opt{ + server.WithTimeout(10 * time.Second), + server.WithLog(logtest.New(t)), + } + testWireSync(t, func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + if len(peers) == 0 { + return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() + } + s := server.New(mesh.Hosts()[1], proto, handler, opts...) + // TODO: this 'Eventually' is somewhat misplaced + require.Eventually(t, func() bool { + for _, h := range mesh.Hosts()[0:] { + if len(h.Mux().Protocols()) == 0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + return s, mesh.Hosts()[1].ID() + }) + }) +} diff --git a/hashsync/interface.go b/hashsync/interface.go new file mode 100644 index 0000000000..4b4c21d5f3 --- /dev/null +++ b/hashsync/interface.go @@ -0,0 +1,13 @@ +package hashsync + +import ( + "context" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" +) + +type requester interface { + Run(context.Context) error + InteractiveRequest(context.Context, p2p.Peer, []byte, server.InteractiveHandler, func(error)) error +} diff --git a/hashsync/monoid_tree_test.go b/hashsync/monoid_tree_test.go index 071e207381..bdc9ee155b 100644 --- a/hashsync/monoid_tree_test.go +++ b/hashsync/monoid_tree_test.go @@ -14,6 +14,7 @@ type sampleID string var _ Ordered = sampleID("") +func (s sampleID) String() string { return string(s) } func (s sampleID) Compare(other any) int { return cmp.Compare(s, other.(sampleID)) } diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 050c43afc8..237dc7013b 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -1,44 +1,105 @@ package hashsync import ( + "errors" + 
"fmt" "reflect" + "strings" ) const ( - defaultMaxSendRange = 16 + defaultMaxSendRange = 16 + defaultItemChunkSize = 16 ) +type MessageType byte + +const ( + MessageTypeDone MessageType = iota + MessageTypeEndRound + MessageTypeEmptySet + MessageTypeEmptyRange + MessageTypeFingerprint + MessageTypeRangeContents + MessageTypeItemBatch +) + +var messageTypes = []string{ + "done", + "endRound", + "emptySet", + "emptyRange", + "fingerprint", + "rangeContents", + "itemBatch", +} + +func (mtype MessageType) String() string { + if int(mtype) < len(messageTypes) { + return messageTypes[mtype] + } + return fmt.Sprintf("", int(mtype)) +} + type SyncMessage interface { + Type() MessageType X() Ordered Y() Ordered Fingerprint() any Count() int - HaveItems() bool + Items() []Ordered +} + +func SyncMessageToString(m SyncMessage) string { + var sb strings.Builder + sb.WriteString("<" + m.Type().String()) + + if x := m.X(); x != nil { + sb.WriteString(" X=" + x.(fmt.Stringer).String()) + } + if y := m.Y(); y != nil { + sb.WriteString(" Y=" + y.(fmt.Stringer).String()) + } + if count := m.Count(); count != 0 { + fmt.Fprintf(&sb, " Count=%d", count) + } + if fp := m.Fingerprint(); fp != nil { + sb.WriteString(" FP=" + fp.(fmt.Stringer).String()) + } + for _, item := range m.Items() { + sb.WriteString(" item=" + item.(fmt.Stringer).String()) + } + sb.WriteString(">") + return sb.String() } // Conduit handles receiving and sending peer messages type Conduit interface { // NextMessage returns the next SyncMessage, or nil if there - // are no more SyncMessages. + // are no more SyncMessages. NextMessage is only called after + // a NextItem call indicates that there are no more items. + // NextMessage will not be called after any of Send...() + // methods is invoked NextMessage() (SyncMessage, error) - // NextItem returns the next item in the set or nil if there - // are no more items - NextItem() (Ordered, error) // SendFingerprint sends range fingerprint to the peer. 
// Count must be > 0 - SendFingerprint(x, y Ordered, fingerprint any, count int) + SendFingerprint(x, y Ordered, fingerprint any, count int) error // SendEmptySet notifies the peer that it we don't have any items. // The corresponding SyncMessage has Count() == 0, X() == nil and Y() == nil - SendEmptySet() + SendEmptySet() error // SendEmptyRange notifies the peer that the specified range // is empty on our side. The corresponding SyncMessage has Count() == 0 - SendEmptyRange(x, y Ordered) - // SendItems sends the local items to the peer, requesting back - // the items peer has in that range. The corresponding - // SyncMessage has HaveItems() == true - SendItems(x, y Ordered, count int, it Iterator) - // SendItemsOnly sends just items without any message - SendItemsOnly(count int, it Iterator) + SendEmptyRange(x, y Ordered) error + // SendItems notifies the peer that the corresponding range items will + // be included in this sync round. The items themselves are sent via + // SendItemsOnly + SendRangeContents(x, y Ordered, count int) error + // SendItems sends just items without any message + SendItems(count, chunkSize int, it Iterator) error + // SendEndRound sends a message that signifies the end of sync round + SendEndRound() error + // SendDone sends a message that notifies the peer that sync is finished + SendDone() error } type Option func(r *RangeSetReconciler) @@ -49,14 +110,20 @@ func WithMaxSendRange(n int) Option { } } +func WithItemChunkSize(n int) Option { + return func(r *RangeSetReconciler) { + r.itemChunkSize = n + } +} + // Iterator points to in item in ItemStore type Iterator interface { // Equal returns true if this iterator is equal to another Iterator Equal(other Iterator) bool // Key returns the key corresponding to iterator Key() Ordered - // Next returns an iterator pointing to the next key or nil - // if this key is the last one in the store + // Next advances the iterator + // TODO: should return bool Next() } @@ -82,14 +149,16 @@ type 
ItemStore interface { } type RangeSetReconciler struct { - is ItemStore - maxSendRange int + is ItemStore + maxSendRange int + itemChunkSize int } func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { rsr := &RangeSetReconciler{ - is: is, - maxSendRange: defaultMaxSendRange, + is: is, + maxSendRange: defaultMaxSendRange, + itemChunkSize: defaultItemChunkSize, } for _, opt := range opts { opt(rsr) @@ -100,19 +169,6 @@ func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { return rsr } -func (rsr *RangeSetReconciler) addItems(c Conduit) error { - for { - item, err := c.NextItem() - if err != nil { - return err - } - if item == nil { - return nil - } - rsr.is.Add(item) - } -} - // func qqqqRmmeK(it Iterator) any { // if it == nil { // return "" @@ -123,65 +179,83 @@ func (rsr *RangeSetReconciler) addItems(c Conduit) error { // return fmt.Sprintf("%s", it.Key()) // } -func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding, start, end Iterator, x, y Ordered) Iterator { +func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding, start, end Iterator, x, y Ordered) (Iterator, error) { if preceding != nil && preceding.Key().Compare(x) > 0 { preceding = nil } // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", // qqqqRmmeK(preceding)) + // TODO: don't re-request range info for the first part of range after stop info := rsr.is.GetRangeInfo(preceding, x, y, -1) // fmt.Fprintf(os.Stderr, "QQQQQ: start=%q end=%q info.Start=%q info.End=%q info.FP=%q x=%q y=%q\n", // qqqqRmmeK(start), qqqqRmmeK(end), qqqqRmmeK(info.Start), qqqqRmmeK(info.End), info.Fingerprint, x, y) switch { + // TODO: make sending items from small chunks resulting from subdivision right away an option // case info.Count != 0 && info.Count <= rsr.maxSendRange: // // If the range is small enough, we send its contents. 
// // The peer may have more items of its own in that range, // // so we can't use SendItemsOnly(), instead we use SendItems, // // which includes our items and asks the peer to send any // // items it has in the range. - // c.SendItems(x, y, info.Count, info.Start) + // if err := c.SendRangeContents(x, y, info.Count); err != nil { + // return nil, err + // } + // if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + // return nil, err + // } case info.Count == 0: // We have no more items in this subrange. // Ask peer to send any items it has in the range - c.SendEmptyRange(x, y) + if err := c.SendEmptyRange(x, y); err != nil { + return nil, err + } default: // The range is non-empty and large enough. // Send fingerprint so that the peer can further subdivide it. - c.SendFingerprint(x, y, info.Fingerprint, info.Count) + if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { + return nil, err + } } // fmt.Fprintf(os.Stderr, "QQQQQ: info.End=%q\n", qqqqRmmeK(info.End)) - return info.End + return info.End, nil } -func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) Iterator { +func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) (it Iterator, done bool, err error) { x := msg.X() y := msg.Y() - if x == nil && y == nil { + done = true + if msg.Type() == MessageTypeEmptySet { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) it := rsr.is.Min() if it == nil { // We don't have any items at all, too - return nil + return nil, true, nil } x = it.Key() y = x } else if x == nil || y == nil { - // TBD: never pass just one nil when decoding!!! 
- panic("invalid range") + return nil, false, errors.New("bad X or Y") } info := rsr.is.GetRangeInfo(preceding, x, y, -1) // fmt.Fprintf(os.Stderr, "msg %s fp %v start %#v end %#v count %d\n", msg, info.Fingerprint, info.Start, info.End, info.Count) switch { - case msg.HaveItems() || msg.Count() == 0: + case msg.Type() == MessageTypeEmptyRange || + msg.Type() == MessageTypeRangeContents || + msg.Type() == MessageTypeEmptySet: // The peer has no more items to send in this range after this // message, as it is either empty or it has sent all of its // items in the range to us, but there may be some items on our // side. In the latter case, send only the items themselves b/c // the range doesn't need any further handling by the peer. if info.Count != 0 { - c.SendItemsOnly(info.Count, info.Start) + done = false + if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + return nil, false, err + } } + case msg.Type() != MessageTypeFingerprint: + return nil, false, fmt.Errorf("unexpected message type %s", msg.Type()) case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): // The range is synced // case (info.Count+1)/2 <= rsr.maxSendRange: @@ -189,12 +263,20 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // The range differs from the peer's version of it, but the it // is small enough (or would be small enough after split) or // empty on our side + done = false if info.Count != 0 { // fmt.Fprintf(os.Stderr, "small incoming range: %s -> SendItems\n", msg) - c.SendItems(x, y, info.Count, info.Start) + if err := c.SendRangeContents(x, y, info.Count); err != nil { + return nil, false, err + } + if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + return nil, false, err + } } else { // fmt.Fprintf(os.Stderr, "small incoming range: %s -> empty range msg\n", msg) - c.SendEmptyRange(x, y) + if err := c.SendEmptyRange(x, y); err != nil { + return nil, false, err + } } default: // Need to 
split the range. @@ -206,53 +288,119 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg panic("BUG: can't split range with count > 1") } middle := part.End.Key() - next := rsr.processSubrange(c, info.Start, part.Start, part.End, x, middle) + next, err := rsr.processSubrange(c, info.Start, part.Start, part.End, x, middle) + if err != nil { + return nil, false, err + } // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) - rsr.processSubrange(c, next, part.End, info.End, middle, y) + _, err = rsr.processSubrange(c, next, part.End, info.End, middle, y) + if err != nil { + return nil, false, err + } // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", // msg.X(), middle, msg.Y(), msg) + done = false } - return info.End + return info.End, done, nil } -func (rsr *RangeSetReconciler) Initiate(c Conduit) { +func (rsr *RangeSetReconciler) Initiate(c Conduit) error { it := rsr.is.Min() if it == nil { - c.SendEmptySet() - return - } - min := it.Key() - info := rsr.is.GetRangeInfo(nil, min, min, -1) - if info.Count != 0 && info.Count < rsr.maxSendRange { - c.SendItems(min, min, info.Count, info.Start) + if err := c.SendEmptySet(); err != nil { + return err + } } else { - c.SendFingerprint(min, min, info.Fingerprint, info.Count) + min := it.Key() + info := rsr.is.GetRangeInfo(nil, min, min, -1) + switch { + case info.Count == 0: + panic("empty full min-min range") + case info.Count < rsr.maxSendRange: + if err := c.SendRangeContents(min, min, info.Count); err != nil { + return err + } + if err := c.SendItems(info.Count, rsr.itemChunkSize, it); err != nil { + return err + } + default: + if err := c.SendFingerprint(min, min, info.Fingerprint, info.Count); err != nil { + return err + } + } + } + if err := c.SendEndRound(); err != nil { + return err } + return nil } -func (rsr *RangeSetReconciler) Process(c Conduit) error { - var msgs []SyncMessage +func (rsr *RangeSetReconciler) getMessages(c Conduit) (msgs 
[]SyncMessage, done bool, err error) { for { msg, err := c.NextMessage() - if err != nil { - return err + switch { + case err != nil: + return nil, false, err + case msg == nil: + return nil, false, errors.New("no end round marker") + default: + switch msg.Type() { + case MessageTypeEndRound: + return msgs, false, nil + case MessageTypeDone: + return msgs, true, nil + default: + msgs = append(msgs, msg) + } } - if msg == nil { - break + } +} + +func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { + var msgs []SyncMessage + // All of the messages need to be received before processing + // them, as processing the messages involves sending more + // messages back to the peer + msgs, done, err = rsr.getMessages(c) + if err != nil { + return false, err + } + if done { + // items already added + if len(msgs) != 0 { + return false, errors.New("non-item messages with 'done' marker") } - msgs = append(msgs, msg) + return done, nil } + done = true + for _, msg := range msgs { + // TODO: pass preceding range, should be safe as the iterator is checked + if msg.Type() == MessageTypeItemBatch { + for _, item := range msg.Items() { + rsr.is.Add(item) + } + continue + } - if err := rsr.addItems(c); err != nil { - return err + _, msgDone, err := rsr.handleMessage(c, nil, msg) + if err != nil { + return false, err + } + if !msgDone { + done = false + } } - for _, msg := range msgs { - // TODO: need to sort the ranges, but also need to be careful - rsr.handleMessage(c, nil, msg) + if done { + err = c.SendDone() + } else { + err = c.SendEndRound() } - return nil + if err != nil { + return false, err + } + return done, nil } func fingerprintEqual(a, b any) bool { @@ -261,6 +409,17 @@ func fingerprintEqual(a, b any) bool { return reflect.DeepEqual(a, b) } +// TBD: !!! use wire types instead of multiple Send* methods in the Conduit interface !!! +// TBD: !!! 
queue outbound messages right in RangeSetReconciler while processing msgs, and no need for done in handleMessage this way ++ no need for complicated logic on the conduit part !!! +// TBD: !!! check that done message present !!! +// Note: can't just use send/recv channels instead of Conduit b/c Receive must be an explicit +// operation done via the underlying Interactor +// TBD: SyncTree +// * rename to SyncTree +// * rm Monoid stuff, use Hash32 for values and Hash12 for fingerprints +// * pass single chars as Hash32 for testing +// * track hashing and XORing during tests to recover the fingerprint substring in tests +// (but not during XOR test!) // TBD: successive messages with payloads can be combined! // TBD: limit the number of rounds (outside RangeSetReconciler) // TBD: process ascending ranges properly diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 0b678a3d42..48ed90c73a 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -1,7 +1,6 @@ package hashsync import ( - "fmt" "math/rand" "slices" "testing" @@ -12,43 +11,40 @@ import ( ) type rangeMessage struct { - x, y Ordered - fp any - count int - haveItems bool + mtype MessageType + x, y Ordered + fp any + count int + items []Ordered } -func (m rangeMessage) X() Ordered { return m.x } -func (m rangeMessage) Y() Ordered { return m.y } -func (m rangeMessage) Fingerprint() any { return m.fp } -func (m rangeMessage) Count() int { return m.count } -func (m rangeMessage) HaveItems() bool { return m.haveItems } +func (m rangeMessage) Type() MessageType { return m.mtype } +func (m rangeMessage) X() Ordered { return m.x } +func (m rangeMessage) Y() Ordered { return m.y } +func (m rangeMessage) Fingerprint() any { return m.fp } +func (m rangeMessage) Count() int { return m.count } +func (m rangeMessage) Items() []Ordered { return m.items } var _ SyncMessage = rangeMessage{} func (m rangeMessage) String() string { - itemsStr := "" - if m.haveItems { - itemsStr = 
fmt.Sprintf(" +items") - } - return fmt.Sprintf("", - m.x, m.y, m.count, m.fp, itemsStr) + return SyncMessageToString(m) } type fakeConduit struct { - t *testing.T - msgs []rangeMessage - items []Ordered - resp *fakeConduit + t *testing.T + msgs []rangeMessage + resp *fakeConduit } var _ Conduit = &fakeConduit{} -func (fc *fakeConduit) done() bool { - if fc.resp == nil { - return true +func (fc *fakeConduit) numItems() int { + n := 0 + for _, m := range fc.msgs { + n += len(m.Items()) } - return len(fc.resp.msgs) == 0 && len(fc.resp.items) == 0 + return n } func (fc *fakeConduit) NextMessage() (SyncMessage, error) { @@ -61,75 +57,81 @@ func (fc *fakeConduit) NextMessage() (SyncMessage, error) { return nil, nil } -func (fc *fakeConduit) NextItem() (Ordered, error) { - if len(fc.items) != 0 { - item := fc.items[0] - fc.items = fc.items[1:] - return item, nil - } - - return nil, nil -} - func (fc *fakeConduit) ensureResp() { if fc.resp == nil { fc.resp = &fakeConduit{t: fc.t} } } -func (fc *fakeConduit) sendMsg(x, y Ordered, fingerprint any, count int, haveItems bool) { +func (fc *fakeConduit) sendMsg(mtype MessageType, x, y Ordered, fingerprint any, count int) { fc.ensureResp() msg := rangeMessage{ - x: x, - y: y, - fp: fingerprint, - count: count, - haveItems: haveItems, + mtype: mtype, + x: x, + y: y, + fp: fingerprint, + count: count, } fc.resp.msgs = append(fc.resp.msgs, msg) } -func (fc *fakeConduit) sendItems(count int, it Iterator) { - require.NotZero(fc.t, count) - require.NotNil(fc.t, it) - fc.ensureResp() - for i := 0; i < count; i++ { - if it.Key() == nil { - panic("fakeConduit.SendItems: went got to the end of the tree") - } - fc.resp.items = append(fc.resp.items, it.Key()) - it.Next() - } -} - -func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) { +func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { require.NotNil(fc.t, x) require.NotNil(fc.t, y) require.NotZero(fc.t, count) 
require.NotNil(fc.t, fingerprint) - fc.sendMsg(x, y, fingerprint, count, false) + fc.sendMsg(MessageTypeFingerprint, x, y, fingerprint, count) + return nil } -func (fc *fakeConduit) SendEmptySet() { - fc.sendMsg(nil, nil, nil, 0, false) +func (fc *fakeConduit) SendEmptySet() error { + fc.sendMsg(MessageTypeEmptySet, nil, nil, nil, 0) + return nil } -func (fc *fakeConduit) SendEmptyRange(x, y Ordered) { +func (fc *fakeConduit) SendEmptyRange(x, y Ordered) error { require.NotNil(fc.t, x) require.NotNil(fc.t, y) - fc.sendMsg(x, y, nil, 0, false) + fc.sendMsg(MessageTypeEmptyRange, x, y, nil, 0) + return nil } -func (fc *fakeConduit) SendItems(x, y Ordered, count int, it Iterator) { - require.Positive(fc.t, count) +func (fc *fakeConduit) SendRangeContents(x, y Ordered, count int) error { require.NotNil(fc.t, x) require.NotNil(fc.t, y) - fc.sendMsg(x, y, nil, count, true) - fc.sendItems(count, it) + fc.sendMsg(MessageTypeRangeContents, x, y, nil, count) + return nil +} + +func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { + require.Positive(fc.t, count) + require.NotZero(fc.t, count) + require.NotNil(fc.t, it) + fc.ensureResp() + for i := 0; i < count; i += itemChunkSize { + msg := rangeMessage{mtype: MessageTypeItemBatch} + n := min(itemChunkSize, count-i) + for n > 0 { + if it.Key() == nil { + panic("fakeConduit.SendItems: went got to the end of the tree") + } + msg.items = append(msg.items, it.Key()) + it.Next() + n-- + } + fc.resp.msgs = append(fc.resp.msgs, msg) + } + return nil +} + +func (fc *fakeConduit) SendEndRound() error { + fc.sendMsg(MessageTypeEndRound, nil, nil, nil, 0) + return nil } -func (fc *fakeConduit) SendItemsOnly(count int, it Iterator) { - fc.sendItems(count, it) +func (fc *fakeConduit) SendDone() error { + fc.sendMsg(MessageTypeDone, nil, nil, nil, 0) + return nil } type dumbStoreIterator struct { @@ -467,25 +469,32 @@ func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...an func runSync(t 
*testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { fc := &fakeConduit{t: t} syncA.Initiate(fc) - require.False(t, fc.done(), "no messages from Initiate") var i int - for i = 0; !fc.done(); i++ { + done := false + // dumpRangeMessages(t, fc.resp.msgs, "A %q -> B %q (init):", storeItemStr(syncA.is), storeItemStr(syncB.is)) + // dumpRangeMessages(t, fc.resp.msgs, "A -> B (init):") + for i = 0; !done; i++ { if i == maxRounds { require.FailNow(t, "too many rounds", "didn't reconcile in %d rounds", i) } fc = fc.resp - // dumpRangeMessages(t, fc.msgs, "A %q -> B %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) nMsg += len(fc.msgs) - nItems += len(fc.items) - syncB.Process(fc) - if fc.done() { + nItems += fc.numItems() + var err error + done, err = syncB.Process(fc) + require.NoError(t, err) + // dumpRangeMessages(t, fc.resp.msgs, "B %q -> A %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) + // dumpRangeMessages(t, fc.resp.msgs, "B -> A:") + if done { break } fc = fc.resp nMsg += len(fc.msgs) - nItems += len(fc.items) - // dumpRangeMessages(t, fc.msgs, "B %q --> A %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) - syncA.Process(fc) + nItems += fc.numItems() + done, err = syncA.Process(fc) + require.NoError(t, err) + // dumpRangeMessages(t, fc.msgs, "A %q --> B %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) + // dumpRangeMessages(t, fc.resp.msgs, "A -> B:") } return i + 1, nMsg, nItems } @@ -545,10 +554,14 @@ func testRangeSync(t *testing.T, storeFactory storeFactory) { t.Logf("maxSendRange: %d", maxSendRange) storeA := makeStore(t, storeFactory, tc.a) disableReAdd(storeA) - syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(maxSendRange)) + syncA := NewRangeSetReconciler(storeA, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) storeB := makeStore(t, storeFactory, tc.b) disableReAdd(storeB) - syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(maxSendRange)) + syncB := 
NewRangeSetReconciler(storeB, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) nRounds, _, _ := runSync(t, syncA, syncB, tc.maxRounds[n]) t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) @@ -600,8 +613,12 @@ func testRandomSync(t *testing.T, storeFactory storeFactory) { slices.Sort(expectedSet) maxSendRange := rand.Intn(16) + 1 - syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(maxSendRange)) - syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(maxSendRange)) + syncA := NewRangeSetReconciler(storeA, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + syncB := NewRangeSetReconciler(storeB, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! // t.Logf("maxSendRange %d a %d b %d n %d", maxSendRange, len(bytesA), len(bytesB), n) @@ -615,7 +632,6 @@ func TestRandomSync(t *testing.T) { forTestStores(t, testRandomSync) } -// TBD: test XOR + big sync // TBD: include initiate round!!! // TBD: use logger for verbose logging (messages) // TBD: in fakeConduit -- check item count against the iterator in SendItems / SendItemsOnly!! diff --git a/hashsync/wire_types.go b/hashsync/wire_types.go new file mode 100644 index 0000000000..c5b6f5a5a3 --- /dev/null +++ b/hashsync/wire_types.go @@ -0,0 +1,109 @@ +package hashsync + +import ( + "github.com/spacemeshos/go-spacemesh/common/types" +) + +//go:generate scalegen + +type Marker struct{} + +func (*Marker) X() Ordered { return nil } +func (*Marker) Y() Ordered { return nil } +func (*Marker) Fingerprint() any { return nil } +func (*Marker) Count() int { return 0 } +func (*Marker) Items() []Ordered { return nil } + +// DoneMessage is a SyncMessage that denotes the end of the synchronization. +// The peer should stop any further processing after receiving this message. 
+type DoneMessage struct{ Marker } + +var _ SyncMessage = &DoneMessage{} + +func (*DoneMessage) Type() MessageType { return MessageTypeDone } + +// EndRoundMessage is a SyncMessage that denotes the end of the sync round. +type EndRoundMessage struct{ Marker } + +var _ SyncMessage = &EndRoundMessage{} + +func (*EndRoundMessage) Type() MessageType { return MessageTypeEndRound } + +// EmptySetMessage is a SyncMessage that denotes an empty set, requesting the +// peer to send all of its items +type EmptySetMessage struct{ Marker } + +var _ SyncMessage = &EmptySetMessage{} + +func (*EmptySetMessage) Type() MessageType { return MessageTypeEmptySet } + +// EmptyRangeMessage notifies the peer that it needs to send all of its items in +// the specified range +type EmptyRangeMessage struct { + RangeX, RangeY types.Hash32 +} + +var _ SyncMessage = &EmptyRangeMessage{} + +func (m *EmptyRangeMessage) Type() MessageType { return MessageTypeEmptyRange } +func (m *EmptyRangeMessage) X() Ordered { return m.RangeX } +func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY } +func (m *EmptyRangeMessage) Fingerprint() any { return nil } +func (m *EmptyRangeMessage) Count() int { return 0 } +func (m *EmptyRangeMessage) Items() []Ordered { return nil } + +// FingerprintMessage contains range fingerprint for comparison against the +// peer's fingerprint of the range with the same bounds [RangeX, RangeY) +type FingerprintMessage struct { + RangeX, RangeY types.Hash32 + RangeFingerprint types.Hash12 + NumItems uint32 +} + +var _ SyncMessage = &FingerprintMessage{} + +func (m *FingerprintMessage) Type() MessageType { return MessageTypeFingerprint } +func (m *FingerprintMessage) X() Ordered { return m.RangeX } +func (m *FingerprintMessage) Y() Ordered { return m.RangeY } +func (m *FingerprintMessage) Fingerprint() any { return m.RangeFingerprint } +func (m *FingerprintMessage) Count() int { return int(m.NumItems) } +func (m *FingerprintMessage) Items() []Ordered { return nil } + +// 
RangeContentsMessage denotes a range for which the set of items has been sent. +// The peer needs to send back any items it has in the same range bounded +// by [RangeX, RangeY) +type RangeContentsMessage struct { + RangeX, RangeY types.Hash32 + NumItems uint32 +} + +var _ SyncMessage = &RangeContentsMessage{} + +func (m *RangeContentsMessage) Type() MessageType { return MessageTypeRangeContents } +func (m *RangeContentsMessage) X() Ordered { return m.RangeX } +func (m *RangeContentsMessage) Y() Ordered { return m.RangeY } +func (m *RangeContentsMessage) Fingerprint() any { return nil } +func (m *RangeContentsMessage) Count() int { return int(m.NumItems) } +func (m *RangeContentsMessage) Items() []Ordered { return nil } + +// ItemBatchMessage denotes a batch of items to be added to the peer's set +type ItemBatchMessage struct { + Contents []types.Hash32 `scale:"max=1024"` +} + +var _ SyncMessage = &ItemBatchMessage{} + +func (m *ItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } +func (m *ItemBatchMessage) X() Ordered { return nil } +func (m *ItemBatchMessage) Y() Ordered { return nil } +func (m *ItemBatchMessage) Fingerprint() any { return nil } +func (m *ItemBatchMessage) Count() int { return 0 } +func (m *ItemBatchMessage) Items() []Ordered { + r := make([]Ordered, len(m.Contents)) + for n, item := range m.Contents { + r[n] = item + } + return r +} + +// TODO: don't do scalegen for empty types diff --git a/hashsync/wire_types_scale.go b/hashsync/wire_types_scale.go new file mode 100644 index 0000000000..bd51c04cdc --- /dev/null +++ b/hashsync/wire_types_scale.go @@ -0,0 +1,258 @@ +// Code generated by github.com/spacemeshos/go-scale/scalegen. DO NOT EDIT. 
+ +// nolint +package hashsync + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func (t *Marker) EncodeScale(enc *scale.Encoder) (total int, err error) { + return total, nil +} + +func (t *Marker) DecodeScale(dec *scale.Decoder) (total int, err error) { + return total, nil +} + +func (t *DoneMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := t.Marker.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *DoneMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := t.Marker.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *EndRoundMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := t.Marker.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *EndRoundMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := t.Marker.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *EmptySetMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := t.Marker.EncodeScale(enc) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *EmptySetMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := t.Marker.DecodeScale(dec) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *EmptyRangeMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.RangeX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.RangeY[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *EmptyRangeMessage) DecodeScale(dec *scale.Decoder) (total 
int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.RangeX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.RangeY[:]) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *FingerprintMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.RangeX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.RangeY[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.RangeFingerprint[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.NumItems)) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *FingerprintMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.RangeX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.RangeY[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.RangeFingerprint[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.NumItems = uint32(field) + } + return total, nil +} + +func (t *RangeContentsMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeByteArray(enc, t.RangeX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.RangeY[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.NumItems)) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *RangeContentsMessage) DecodeScale(dec *scale.Decoder) 
(total int, err error) { + { + n, err := scale.DecodeByteArray(dec, t.RangeX[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.DecodeByteArray(dec, t.RangeY[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.NumItems = uint32(field) + } + return total, nil +} + +func (t *ItemBatchMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.Contents, 1024) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ItemBatchMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + field, n, err := scale.DecodeStructSliceWithLimit[types.Hash32](dec, 1024) + if err != nil { + return total, err + } + total += n + t.Contents = field + } + return total, nil +} diff --git a/hashsync/xorsync.go b/hashsync/xorsync.go index 88063a0054..e14773a21d 100644 --- a/hashsync/xorsync.go +++ b/hashsync/xorsync.go @@ -46,8 +46,6 @@ func (m Hash32To12Xor) Fingerprint(v any) any { // TODO: fix types.CalcHash12() h := v.(types.Hash32) var r types.Hash12 - // copy(r[:], h[20:]) - // return r hasher := hashPool.Get().(*blake3.Hasher) defer func() { hasher.Reset() diff --git a/hashsync/xorsync_test.go b/hashsync/xorsync_test.go index 80e6cce385..67a4c95eae 100644 --- a/hashsync/xorsync_test.go +++ b/hashsync/xorsync_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,10 +44,6 @@ func collectStoreItems[T Ordered](is ItemStore) (r []T) { } } -const numTestHashes = 100000 - -// const numTestHashes = 100 - type catchTransferTwice struct { ItemStore t *testing.T @@ -56,7 +53,7 @@ type catchTransferTwice struct { func (s *catchTransferTwice) Add(k Ordered) { h := k.(types.Hash32) _, found := s.added[h] - 
require.False(s.t, found, "hash sent twice") + assert.False(s.t, found, "hash sent twice") s.ItemStore.Add(k) if s.added == nil { s.added = make(map[types.Hash32]bool) @@ -64,50 +61,65 @@ func (s *catchTransferTwice) Add(k Ordered) { s.added[h] = true } -const xorTestMaxSendRange = 1 +type xorSyncTestConfig struct { + maxSendRange int + numTestHashes int + minNumSpecificA int + maxNumSpecificA int + minNumSpecificB int + maxNumSpecificB int +} -func TestBigSyncHash32(t *testing.T) { - numSpecificA := rand.Intn(96) + 4 - numSpecificB := rand.Intn(96) + 4 - // numSpecificA := rand.Intn(6) + 4 - // numSpecificB := rand.Intn(6) + 4 - src := make([]types.Hash32, numTestHashes) +func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(syncA, syncB *RangeSetReconciler, numSpecific int)) { + numSpecificA := rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA + numSpecificB := rand.Intn(cfg.maxNumSpecificB+1-cfg.minNumSpecificB) + cfg.minNumSpecificB + src := make([]types.Hash32, cfg.numTestHashes) for n := range src { src[n] = types.RandomHash() } - sliceA := src[:numTestHashes-numSpecificB] + sliceA := src[:cfg.numTestHashes-numSpecificB] storeA := NewMonoidTreeStore(Hash32To12Xor{}) for _, h := range sliceA { storeA.Add(h) } storeA = &catchTransferTwice{t: t, ItemStore: storeA} - syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(xorTestMaxSendRange)) + syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(cfg.maxSendRange)) - sliceB := append([]types.Hash32(nil), src[:numTestHashes-numSpecificB-numSpecificA]...) - sliceB = append(sliceB, src[numTestHashes-numSpecificB:]...) + sliceB := append([]types.Hash32(nil), src[:cfg.numTestHashes-numSpecificB-numSpecificA]...) + sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) 
storeB := NewMonoidTreeStore(Hash32To12Xor{}) for _, h := range sliceB { storeB.Add(h) } storeB = &catchTransferTwice{t: t, ItemStore: storeB} - syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(xorTestMaxSendRange)) - - nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) - excess := float64(nItems-numSpecificA-numSpecificB) / float64(numSpecificA+numSpecificB) - t.Logf("numSpecificA: %d, numSpecificB: %d, nRounds: %d, nMsg: %d, nItems: %d, excess: %.2f", - numSpecificA, numSpecificB, nRounds, nMsg, nItems, excess) + syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(cfg.maxSendRange)) slices.SortFunc(src, func(a, b types.Hash32) int { return a.Compare(b) }) + + sync(syncA, syncB, numSpecificA+numSpecificB) + itemsA := collectStoreItems[types.Hash32](storeA) itemsB := collectStoreItems[types.Hash32](storeB) require.Equal(t, itemsA, itemsB) require.Equal(t, src, itemsA) } -// TODO: try catching items sent twice in a simpler test -// TODO: check why insertion takes so long (1000000 items => too long wait) -// TODO: number of items transferred is unreasonable for 100k total / 1 range size: -// xorsync_test.go:56: numSpecificA: 141, numSpecificB: 784, nRounds: 11, nMsg: 13987, nItems: 3553 +func TestBigSyncHash32(t *testing.T) { + cfg := xorSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 100000, + minNumSpecificA: 4, + maxNumSpecificA: 100, + minNumSpecificB: 4, + maxNumSpecificB: 100, + } + verifyXORSync(t, cfg, func(syncA, syncB *RangeSetReconciler, numSpecific int) { + nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) + itemCoef := float64(nItems) / float64(numSpecific) + t.Logf("numSpecific: %d, nRounds: %d, nMsg: %d, nItems: %d, itemCoef: %.2f", + numSpecific, nRounds, nMsg, nItems, itemCoef) + }) +} From a043c878af4bcf8bfec86176fac29caaec98ba12 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 10 Jan 2024 23:27:40 +0400 Subject: [PATCH 10/76] hashsync: add items to the store even in case of errors Not tested yet --- 
hashsync/rangesync.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 237dc7013b..afdbdbc468 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -340,9 +340,9 @@ func (rsr *RangeSetReconciler) getMessages(c Conduit) (msgs []SyncMessage, done msg, err := c.NextMessage() switch { case err != nil: - return nil, false, err + return msgs, false, err case msg == nil: - return nil, false, errors.New("no end round marker") + return msgs, false, errors.New("no end round marker") default: switch msg.Type() { case MessageTypeEndRound: @@ -362,9 +362,6 @@ func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { // them, as processing the messages involves sending more // messages back to the peer msgs, done, err = rsr.getMessages(c) - if err != nil { - return false, err - } if done { // items already added if len(msgs) != 0 { @@ -374,7 +371,6 @@ func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { } done = true for _, msg := range msgs { - // TODO: pass preceding range, should be safe as the iterator is checked if msg.Type() == MessageTypeItemBatch { for _, item := range msg.Items() { rsr.is.Add(item) @@ -382,15 +378,26 @@ func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { continue } - _, msgDone, err := rsr.handleMessage(c, nil, msg) + // If there was an error, just add any items received, + // but ignore other messages if err != nil { - return false, err + continue } + + // TODO: pass preceding range. 
Somehow, currently the code + // breaks if we capture the iterator from handleMessage and + // pass it to the next handleMessage call (it shouldn't) + var msgDone bool + _, msgDone, err = rsr.handleMessage(c, nil, msg) if !msgDone { done = false } } + if err != nil { + return false, err + } + if done { err = c.SendDone() } else { @@ -409,6 +416,7 @@ func fingerprintEqual(a, b any) bool { return reflect.DeepEqual(a, b) } +// TBD: test: add items to the store even in case of NextMessage() failure // TBD: !!! use wire types instead of multiple Send* methods in the Conduit interface !!! // TBD: !!! queue outbound messages right in RangeSetReconciler while processing msgs, and no need for done in handleMessage this way ++ no need for complicated logic on the conduit part !!! // TBD: !!! check that done message present !!! From b8cc6b010b756c6abbcb589c0c6858051ce693fe Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 11 Jan 2024 00:24:47 +0400 Subject: [PATCH 11/76] hashsync: rename MonoidTree to SyncTree --- hashsync/monoid_tree_store.go | 90 ---- hashsync/rangesync_test.go | 12 +- hashsync/{monoid_tree.go => sync_tree.go} | 508 +++++++++--------- ...{monoid_tree_test.go => sync_tree_test.go} | 112 ++-- hashsync/sync_trees_store.go | 90 ++++ hashsync/xorsync_test.go | 4 +- 6 files changed, 408 insertions(+), 408 deletions(-) delete mode 100644 hashsync/monoid_tree_store.go rename hashsync/{monoid_tree.go => sync_tree.go} (56%) rename hashsync/{monoid_tree_test.go => sync_tree_test.go} (84%) create mode 100644 hashsync/sync_trees_store.go diff --git a/hashsync/monoid_tree_store.go b/hashsync/monoid_tree_store.go deleted file mode 100644 index 4ea94b0558..0000000000 --- a/hashsync/monoid_tree_store.go +++ /dev/null @@ -1,90 +0,0 @@ -package hashsync - -type monoidTreeIterator struct { - mt MonoidTree - ptr MonoidTreePointer -} - -var _ Iterator = &monoidTreeIterator{} - -func (it *monoidTreeIterator) Equal(other Iterator) bool { - o := other.(*monoidTreeIterator) - 
if it.mt != o.mt { - panic("comparing iterators from different MonoidTreeStore") - } - return it.ptr.Equal(o.ptr) -} - -func (it *monoidTreeIterator) Key() Ordered { - return it.ptr.Key() -} - -func (it *monoidTreeIterator) Next() { - it.ptr.Next() - if it.ptr.Key() == nil { - it.ptr = it.mt.Min() - } -} - -type MonoidTreeStore struct { - mt MonoidTree -} - -var _ ItemStore = &MonoidTreeStore{} - -func NewMonoidTreeStore(m Monoid) ItemStore { - return &MonoidTreeStore{ - mt: NewMonoidTree(CombineMonoids(m, CountingMonoid{})), - } -} - -// Add implements ItemStore. -func (mts *MonoidTreeStore) Add(k Ordered) { - mts.mt.Add(k) -} - -func (mts *MonoidTreeStore) iter(ptr MonoidTreePointer) Iterator { - if ptr == nil { - return nil - } - return &monoidTreeIterator{ - mt: mts.mt, - ptr: ptr, - } -} - -// GetRangeInfo implements ItemStore. -func (mts *MonoidTreeStore) GetRangeInfo(preceding Iterator, x Ordered, y Ordered, count int) RangeInfo { - var stop FingerprintPredicate - var node MonoidTreePointer - if preceding != nil { - p := preceding.(*monoidTreeIterator) - if p.mt != mts.mt { - panic("GetRangeInfo: preceding iterator from a wrong MonoidTreeStore") - } - node = p.ptr - } - if count >= 0 { - stop = func(fp any) bool { - return CombinedSecond[int](fp) > count - } - } - fp, startPtr, endPtr := mts.mt.RangeFingerprint(node, x, y, stop) - cfp := fp.(CombinedFingerprint) - return RangeInfo{ - Fingerprint: cfp.First, - Count: cfp.Second.(int), - Start: mts.iter(startPtr), - End: mts.iter(endPtr), - } -} - -// Min implements ItemStore. -func (mts *MonoidTreeStore) Min() Iterator { - return mts.iter(mts.mt.Min()) -} - -// Max implements ItemStore. 
-func (mts *MonoidTreeStore) Max() Iterator { - return mts.iter(mts.mt.Max()) -} diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 48ed90c73a..4b5f2a29db 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -396,15 +396,15 @@ func makeDumbStore(t *testing.T) ItemStore { return &dumbStore{} } -func makeMonoidTreeStore(t *testing.T) ItemStore { - return NewMonoidTreeStore(sampleMonoid{}) +func makeSyncTreeStore(t *testing.T) ItemStore { + return NewSyncTreeStore(sampleMonoid{}) } -func makeVerifiedMonoidTreeStore(t *testing.T) ItemStore { +func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { return &verifiedStore{ t: t, knownGood: makeDumbStore(t), - store: makeMonoidTreeStore(t), + store: makeSyncTreeStore(t), } } @@ -442,11 +442,11 @@ var testStores = []struct { }, { name: "monoid tree store", - factory: makeMonoidTreeStore, + factory: makeSyncTreeStore, }, { name: "verified monoid tree store", - factory: makeVerifiedMonoidTreeStore, + factory: makeVerifiedSyncTreeStore, }, } diff --git a/hashsync/monoid_tree.go b/hashsync/sync_tree.go similarity index 56% rename from hashsync/monoid_tree.go rename to hashsync/sync_tree.go index 3b0133455c..181fa68ba8 100644 --- a/hashsync/monoid_tree.go +++ b/hashsync/sync_tree.go @@ -31,39 +31,39 @@ func (fpred FingerprintPredicate) Match(y any) bool { return fpred != nil && fpred(y) } -type MonoidTree interface { - Copy() MonoidTree +type SyncTree interface { + Copy() SyncTree Fingerprint() any Add(k Ordered) Set(k Ordered, v any) Lookup(k Ordered) (any, bool) - Min() MonoidTreePointer - Max() MonoidTreePointer - RangeFingerprint(ptr MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreePointer) + Min() SyncTreePointer + Max() SyncTreePointer + RangeFingerprint(ptr SyncTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode SyncTreePointer) Dump() string } -func MonoidTreeFromSortedSlice[T 
Ordered](m Monoid, items []T) MonoidTree { +func SyncTreeFromSortedSlice[T Ordered](m Monoid, items []T) SyncTree { s := make([]Ordered, len(items)) for n, item := range items { s[n] = item } - mt := NewMonoidTree(m).(*monoidTree) - mt.root = mt.buildFromSortedSlice(nil, s) - return mt + st := NewSyncTree(m).(*syncTree) + st.root = st.buildFromSortedSlice(nil, s) + return st } -func MonoidTreeFromSlice[T Ordered](m Monoid, items []T) MonoidTree { +func SyncTreeFromSlice[T Ordered](m Monoid, items []T) SyncTree { sorted := make([]T, len(items)) copy(sorted, items) slices.SortFunc(sorted, func(a, b T) int { return a.Compare(b) }) - return MonoidTreeFromSortedSlice(m, items) + return SyncTreeFromSortedSlice(m, items) } -type MonoidTreePointer interface { - Equal(other MonoidTreePointer) bool +type SyncTreePointer interface { + Equal(other SyncTreePointer) bool Key() Ordered Value() any Prev() @@ -105,24 +105,24 @@ func (d dir) String() string { const initialParentStackSize = 32 -type monoidTreePointer struct { - parentStack []*monoidTreeNode - node *monoidTreeNode +type syncTreePointer struct { + parentStack []*syncTreeNode + node *syncTreeNode } -var _ MonoidTreePointer = &monoidTreePointer{} +var _ SyncTreePointer = &syncTreePointer{} -func (p *monoidTreePointer) clone() *monoidTreePointer { +func (p *syncTreePointer) clone() *syncTreePointer { // TODO: copy node stack - r := &monoidTreePointer{ - parentStack: make([]*monoidTreeNode, len(p.parentStack), cap(p.parentStack)), + r := &syncTreePointer{ + parentStack: make([]*syncTreeNode, len(p.parentStack), cap(p.parentStack)), node: p.node, } copy(r.parentStack, p.parentStack) return r } -func (p *monoidTreePointer) parent() { +func (p *syncTreePointer) parent() { n := len(p.parentStack) if n == 0 { p.node = nil @@ -133,21 +133,21 @@ func (p *monoidTreePointer) parent() { } } -func (p *monoidTreePointer) left() { +func (p *syncTreePointer) left() { if p.node != nil { p.parentStack = append(p.parentStack, p.node) 
p.node = p.node.left } } -func (p *monoidTreePointer) right() { +func (p *syncTreePointer) right() { if p.node != nil { p.parentStack = append(p.parentStack, p.node) p.node = p.node.right } } -func (p *monoidTreePointer) min() { +func (p *syncTreePointer) min() { for { switch { case p.node == nil || p.node.left == nil: @@ -158,7 +158,7 @@ func (p *monoidTreePointer) min() { } } -func (p *monoidTreePointer) max() { +func (p *syncTreePointer) max() { for { switch { case p.node == nil || p.node.right == nil: @@ -169,14 +169,14 @@ func (p *monoidTreePointer) max() { } } -func (p *monoidTreePointer) Equal(other MonoidTreePointer) bool { +func (p *syncTreePointer) Equal(other SyncTreePointer) bool { if other == nil { return p.node == nil } - return p.node == other.(*monoidTreePointer).node + return p.node == other.(*syncTreePointer).node } -func (p *monoidTreePointer) Prev() { +func (p *syncTreePointer) Prev() { switch { case p.node == nil: case p.node.left != nil: @@ -194,7 +194,7 @@ func (p *monoidTreePointer) Prev() { } } -func (p *monoidTreePointer) Next() { +func (p *syncTreePointer) Next() { switch { case p.node == nil: case p.node.right != nil: @@ -212,23 +212,23 @@ func (p *monoidTreePointer) Next() { } } -func (p *monoidTreePointer) Key() Ordered { +func (p *syncTreePointer) Key() Ordered { if p.node == nil { return nil } return p.node.key } -func (p *monoidTreePointer) Value() any { +func (p *syncTreePointer) Value() any { if p.node == nil { return nil } return p.node.value } -type monoidTreeNode struct { - left *monoidTreeNode - right *monoidTreeNode +type syncTreeNode struct { + left *syncTreeNode + right *syncTreeNode key Ordered value any max Ordered @@ -236,55 +236,55 @@ type monoidTreeNode struct { flags flags } -func (mn *monoidTreeNode) red() bool { - return mn != nil && (mn.flags&flagBlack) == 0 +func (sn *syncTreeNode) red() bool { + return sn != nil && (sn.flags&flagBlack) == 0 } -func (mn *monoidTreeNode) black() bool { - return mn == nil || 
(mn.flags&flagBlack) != 0 +func (sn *syncTreeNode) black() bool { + return sn == nil || (sn.flags&flagBlack) != 0 } -func (mn *monoidTreeNode) child(dir dir) *monoidTreeNode { - if mn == nil { +func (sn *syncTreeNode) child(dir dir) *syncTreeNode { + if sn == nil { return nil } if dir == left { - return mn.left + return sn.left } - return mn.right + return sn.right } -func (mn *monoidTreeNode) Key() Ordered { return mn.key } +func (sn *syncTreeNode) Key() Ordered { return sn.key } -func (mn *monoidTreeNode) dump(w io.Writer, indent int) { +func (sn *syncTreeNode) dump(w io.Writer, indent int) { indentStr := strings.Repeat(" ", indent) - fmt.Fprintf(w, "%skey: %v\n", indentStr, mn.key) - fmt.Fprintf(w, "%smax: %v\n", indentStr, mn.max) - fmt.Fprintf(w, "%sfp: %v\n", indentStr, mn.fingerprint) + fmt.Fprintf(w, "%skey: %v\n", indentStr, sn.key) + fmt.Fprintf(w, "%smax: %v\n", indentStr, sn.max) + fmt.Fprintf(w, "%sfp: %v\n", indentStr, sn.fingerprint) color := "red" - if mn.black() { + if sn.black() { color = "black" } fmt.Fprintf(w, "%scolor: %v\n", indentStr, color) - if mn.left != nil { + if sn.left != nil { fmt.Fprintf(w, "%sleft:\n", indentStr) - mn.left.dump(w, indent+1) - if mn.left.key.Compare(mn.key) >= 0 { + sn.left.dump(w, indent+1) + if sn.left.key.Compare(sn.key) >= 0 { fmt.Fprintf(w, "%sERROR: left key >= parent key\n", indentStr) } } - if mn.right != nil { + if sn.right != nil { fmt.Fprintf(w, "%sright:\n", indentStr) - mn.right.dump(w, indent+1) - if mn.right.key.Compare(mn.key) <= 0 { + sn.right.dump(w, indent+1) + if sn.right.key.Compare(sn.key) <= 0 { fmt.Fprintf(w, "%sERROR: right key <= parent key\n", indentStr) } } } -func (mn *monoidTreeNode) dumpSubtree() string { +func (sn *syncTreeNode) dumpSubtree() string { var sb strings.Builder - mn.dump(&sb, 0) + sn.dump(&sb, 0) return sb.String() } @@ -292,305 +292,305 @@ func (mn *monoidTreeNode) dumpSubtree() string { // so that it can be used in further cloned trees. 
// A non-cloned node cannot have any cloned children, so the function // stops the recursion at any non-cloned node. -func (mn *monoidTreeNode) cleanCloned() { - if mn == nil || mn.flags&flagCloned == 0 { +func (sn *syncTreeNode) cleanCloned() { + if sn == nil || sn.flags&flagCloned == 0 { return } - mn.flags &^= flagCloned - mn.left.cleanCloned() - mn.right.cleanCloned() + sn.flags &^= flagCloned + sn.left.cleanCloned() + sn.right.cleanCloned() } -type monoidTree struct { +type syncTree struct { m Monoid - root *monoidTreeNode - cachedMinPtr *monoidTreePointer - cachedMaxPtr *monoidTreePointer + root *syncTreeNode + cachedMinPtr *syncTreePointer + cachedMaxPtr *syncTreePointer } -func NewMonoidTree(m Monoid) MonoidTree { - return &monoidTree{m: m} +func NewSyncTree(m Monoid) SyncTree { + return &syncTree{m: m} } -func (mt *monoidTree) Copy() MonoidTree { +func (st *syncTree) Copy() SyncTree { // Clean flagCloned from any nodes created specifically // for this subtree. This will mean they will have to be // re-cloned if they need to be changed again. 
- mt.root.cleanCloned() + st.root.cleanCloned() // Don't reuse cachedMinPtr / cachedMaxPtr for the cloned // tree to be on the safe side - return &monoidTree{ - m: mt.m, - root: mt.root, + return &syncTree{ + m: st.m, + root: st.root, } } -func (mt *monoidTree) rootPtr() *monoidTreePointer { - return &monoidTreePointer{ - parentStack: make([]*monoidTreeNode, 0, initialParentStackSize), - node: mt.root, +func (st *syncTree) rootPtr() *syncTreePointer { + return &syncTreePointer{ + parentStack: make([]*syncTreeNode, 0, initialParentStackSize), + node: st.root, } } -func (mt *monoidTree) ensureCloned(mn *monoidTreeNode) *monoidTreeNode { - if mn.flags&flagCloned != 0 { - return mn +func (st *syncTree) ensureCloned(sn *syncTreeNode) *syncTreeNode { + if sn.flags&flagCloned != 0 { + return sn } - cloned := *mn + cloned := *sn cloned.flags |= flagCloned return &cloned } -func (mt *monoidTree) setChild(mn *monoidTreeNode, dir dir, child *monoidTreeNode) *monoidTreeNode { - if mn == nil { +func (st *syncTree) setChild(sn *syncTreeNode, dir dir, child *syncTreeNode) *syncTreeNode { + if sn == nil { panic("setChild for a nil node") } - if mn.child(dir) == child { - return mn + if sn.child(dir) == child { + return sn } - mn = mt.ensureCloned(mn) + sn = st.ensureCloned(sn) if dir == left { - mn.left = child + sn.left = child } else { - mn.right = child + sn.right = child } - return mn + return sn } -func (mt *monoidTree) flip(mn *monoidTreeNode) *monoidTreeNode { - if mn.left == nil || mn.right == nil { +func (st *syncTree) flip(sn *syncTreeNode) *syncTreeNode { + if sn.left == nil || sn.right == nil { panic("can't flip color with one or more nil children") } - left := mt.ensureCloned(mn.left) - right := mt.ensureCloned(mn.right) - mn = mt.ensureCloned(mn) - mn.left = left - mn.right = right + left := st.ensureCloned(sn.left) + right := st.ensureCloned(sn.right) + sn = st.ensureCloned(sn) + sn.left = left + sn.right = right - mn.flags ^= flagBlack + sn.flags ^= flagBlack 
left.flags ^= flagBlack right.flags ^= flagBlack - return mn + return sn } -func (mt *monoidTree) Min() MonoidTreePointer { - if mt.root == nil { +func (st *syncTree) Min() SyncTreePointer { + if st.root == nil { return nil } - if mt.cachedMinPtr == nil { - mt.cachedMinPtr = mt.rootPtr() - mt.cachedMinPtr.min() + if st.cachedMinPtr == nil { + st.cachedMinPtr = st.rootPtr() + st.cachedMinPtr.min() } - if mt.cachedMinPtr.node == nil { + if st.cachedMinPtr.node == nil { panic("BUG: no minNode in a non-empty tree") } - return mt.cachedMinPtr.clone() + return st.cachedMinPtr.clone() } -func (mt *monoidTree) Max() MonoidTreePointer { - if mt.root == nil { +func (st *syncTree) Max() SyncTreePointer { + if st.root == nil { return nil } - if mt.cachedMaxPtr == nil { - mt.cachedMaxPtr = mt.rootPtr() - mt.cachedMaxPtr.max() + if st.cachedMaxPtr == nil { + st.cachedMaxPtr = st.rootPtr() + st.cachedMaxPtr.max() } - if mt.cachedMaxPtr.node == nil { + if st.cachedMaxPtr.node == nil { panic("BUG: no maxNode in a non-empty tree") } - return mt.cachedMaxPtr.clone() + return st.cachedMaxPtr.clone() } -func (mt *monoidTree) Fingerprint() any { - if mt.root == nil { - return mt.m.Identity() +func (st *syncTree) Fingerprint() any { + if st.root == nil { + return st.m.Identity() } - return mt.root.fingerprint + return st.root.fingerprint } -func (mt *monoidTree) newNode(parent *monoidTreeNode, k Ordered, v any) *monoidTreeNode { - return &monoidTreeNode{ +func (st *syncTree) newNode(parent *syncTreeNode, k Ordered, v any) *syncTreeNode { + return &syncTreeNode{ key: k, value: v, max: k, - fingerprint: mt.m.Fingerprint(k), + fingerprint: st.m.Fingerprint(k), } } -func (mt *monoidTree) buildFromSortedSlice(parent *monoidTreeNode, s []Ordered) *monoidTreeNode { +func (st *syncTree) buildFromSortedSlice(parent *syncTreeNode, s []Ordered) *syncTreeNode { switch len(s) { case 0: return nil case 1: - return mt.newNode(nil, s[0], nil) + return st.newNode(nil, s[0], nil) } middle := len(s) / 2 - 
node := mt.newNode(parent, s[middle], nil) - node.left = mt.buildFromSortedSlice(node, s[:middle]) - node.right = mt.buildFromSortedSlice(node, s[middle+1:]) + node := st.newNode(parent, s[middle], nil) + node.left = st.buildFromSortedSlice(node, s[:middle]) + node.right = st.buildFromSortedSlice(node, s[middle+1:]) if node.left != nil { - node.fingerprint = mt.m.Op(node.left.fingerprint, node.fingerprint) + node.fingerprint = st.m.Op(node.left.fingerprint, node.fingerprint) } if node.right != nil { - node.fingerprint = mt.m.Op(node.fingerprint, node.right.fingerprint) + node.fingerprint = st.m.Op(node.fingerprint, node.right.fingerprint) node.max = node.right.max } return node } -func (mt *monoidTree) safeFingerprint(mn *monoidTreeNode) any { - if mn == nil { - return mt.m.Identity() +func (st *syncTree) safeFingerprint(sn *syncTreeNode) any { + if sn == nil { + return st.m.Identity() } - return mn.fingerprint + return sn.fingerprint } -func (mt *monoidTree) updateFingerprintAndMax(mn *monoidTreeNode) { - fp := mt.m.Op(mt.safeFingerprint(mn.left), mt.m.Fingerprint(mn.key)) - fp = mt.m.Op(fp, mt.safeFingerprint(mn.right)) - newMax := mn.key - if mn.right != nil { - newMax = mn.right.max +func (st *syncTree) updateFingerprintAndMax(sn *syncTreeNode) { + fp := st.m.Op(st.safeFingerprint(sn.left), st.m.Fingerprint(sn.key)) + fp = st.m.Op(fp, st.safeFingerprint(sn.right)) + newMax := sn.key + if sn.right != nil { + newMax = sn.right.max } - if mn.flags&flagCloned == 0 && - (!reflect.DeepEqual(mn.fingerprint, fp) || mn.max.Compare(newMax) != 0) { + if sn.flags&flagCloned == 0 && + (!reflect.DeepEqual(sn.fingerprint, fp) || sn.max.Compare(newMax) != 0) { panic("BUG: updating fingerprint/max for a non-cloned node") } - mn.fingerprint = fp - mn.max = newMax + sn.fingerprint = fp + sn.max = newMax } -func (mt *monoidTree) rotate(mn *monoidTreeNode, d dir) *monoidTreeNode { - // mn.verify() +func (st *syncTree) rotate(sn *syncTreeNode, d dir) *syncTreeNode { + // sn.verify() 
rd := d.flip() - tmp := mn.child(rd) + tmp := sn.child(rd) if tmp == nil { panic("BUG: nil parent after rotate") } // fmt.Fprintf(os.Stderr, "QQQQQ: rotate %s (child at %s is %s): subtree:\n%s\n", - // d, rd, tmp.key, mn.dumpSubtree()) - mn = mt.setChild(mn, rd, tmp.child(d)) - tmp = mt.setChild(tmp, d, mn) + // d, rd, tmp.key, sn.dumpSubtree()) + sn = st.setChild(sn, rd, tmp.child(d)) + tmp = st.setChild(tmp, d, sn) // copy node color to the tmp - tmp.flags = (tmp.flags &^ flagBlack) | (mn.flags & flagBlack) - mn.flags &^= flagBlack // set to red + tmp.flags = (tmp.flags &^ flagBlack) | (sn.flags & flagBlack) + sn.flags &^= flagBlack // set to red - // it's important to update mn first as it may be the new right child of + // it's important to update sn first as it may be the new right child of // tmp, and we need to update tmp.max too - mt.updateFingerprintAndMax(mn) - mt.updateFingerprintAndMax(tmp) + st.updateFingerprintAndMax(sn) + st.updateFingerprintAndMax(tmp) return tmp } -func (mt *monoidTree) doubleRotate(mn *monoidTreeNode, d dir) *monoidTreeNode { +func (st *syncTree) doubleRotate(sn *syncTreeNode, d dir) *syncTreeNode { rd := d.flip() - mn = mt.setChild(mn, rd, mt.rotate(mn.child(rd), rd)) - return mt.rotate(mn, d) + sn = st.setChild(sn, rd, st.rotate(sn.child(rd), rd)) + return st.rotate(sn, d) } -func (mt *monoidTree) Add(k Ordered) { - mt.add(k, nil, false) +func (st *syncTree) Add(k Ordered) { + st.add(k, nil, false) } -func (mt *monoidTree) Set(k Ordered, v any) { - mt.add(k, v, true) +func (st *syncTree) Set(k Ordered, v any) { + st.add(k, v, true) } -func (mt *monoidTree) add(k Ordered, v any, set bool) { - mt.root = mt.insert(mt.root, k, v, true, set) - if mt.root.flags&flagBlack == 0 { - mt.root = mt.ensureCloned(mt.root) - mt.root.flags |= flagBlack +func (st *syncTree) add(k Ordered, v any, set bool) { + st.root = st.insert(st.root, k, v, true, set) + if st.root.flags&flagBlack == 0 { + st.root = st.ensureCloned(st.root) + st.root.flags |= 
flagBlack } } -func (mt *monoidTree) insert(mn *monoidTreeNode, k Ordered, v any, rb, set bool) *monoidTreeNode { +func (st *syncTree) insert(sn *syncTreeNode, k Ordered, v any, rb, set bool) *syncTreeNode { // simplified insert implementation idea from // https://zarif98sjs.github.io/blog/blog/redblacktree/ - if mn == nil { - mn = mt.newNode(nil, k, v) + if sn == nil { + sn = st.newNode(nil, k, v) // the new node is not really "cloned", but at this point it's // only present in this tree so we can safely modify it // without allocating new nodes - mn.flags |= flagCloned + sn.flags |= flagCloned // when the tree is being modified, cached min/max ptrs are no longer valid - mt.cachedMinPtr = nil - mt.cachedMaxPtr = nil - return mn + st.cachedMinPtr = nil + st.cachedMaxPtr = nil + return sn } - c := k.Compare(mn.key) + c := k.Compare(sn.key) if c == 0 { - if v != mn.value { - mn = mt.ensureCloned(mn) - mn.value = v + if v != sn.value { + sn = st.ensureCloned(sn) + sn.value = v } - return mn + return sn } d := left if c > 0 { d = right } - oldChild := mn.child(d) - newChild := mt.insert(oldChild, k, v, rb, set) - mn = mt.setChild(mn, d, newChild) + oldChild := sn.child(d) + newChild := st.insert(oldChild, k, v, rb, set) + sn = st.setChild(sn, d, newChild) updateFP := true if rb { // non-red-black insert is used for testing - mn, updateFP = mt.insertFixup(mn, d, oldChild != newChild) + sn, updateFP = st.insertFixup(sn, d, oldChild != newChild) } if updateFP { - mt.updateFingerprintAndMax(mn) + st.updateFingerprintAndMax(sn) } - return mn + return sn } // insertFixup fixes a subtree after insert according to Red-Black tree rules. // It returns the updated node and a boolean indicating whether the fingerprint/max // update is needed. 
The latter is NOT the case -func (mt *monoidTree) insertFixup(mn *monoidTreeNode, d dir, updateFP bool) (*monoidTreeNode, bool) { - child := mn.child(d) +func (st *syncTree) insertFixup(sn *syncTreeNode, d dir, updateFP bool) (*syncTreeNode, bool) { + child := sn.child(d) rd := d.flip() switch { case child.black(): - return mn, true - case mn.child(rd).red(): - // both children of mn are red => any child has 2 reds in a row + return sn, true + case sn.child(rd).red(): + // both children of sn are red => any child has 2 reds in a row // (LL LR RR RL) => flip colors if child.child(d).red() || child.child(rd).red() { - return mt.flip(mn), true + return st.flip(sn), true } - return mn, true + return sn, true case child.child(d).red(): - // another child of mn is black + // another child of sn is black // any child has 2 reds in a row (LL RR) => rotate - // rotate will update fingerprint of mn and the node + // rotate will update fingerprint of sn and the node // that replaces it - return mt.rotate(mn, rd), updateFP + return st.rotate(sn, rd), updateFP case child.child(rd).red(): - // another child of mn is black + // another child of sn is black // any child has 2 reds in a row (LR RL) => align first, then rotate - // doubleRotate will update fingerprint of mn and the node + // doubleRotate will update fingerprint of sn and the node // that replaces it - return mt.doubleRotate(mn, rd), updateFP + return st.doubleRotate(sn, rd), updateFP default: - return mn, true + return sn, true } } -func (mt *monoidTree) Lookup(k Ordered) (any, bool) { +func (st *syncTree) Lookup(k Ordered) (any, bool) { // TODO: lookups shouldn't cause any allocation! 
- ptr := mt.rootPtr() - if !mt.findGTENode(ptr, k) || ptr.node == nil || ptr.Key().Compare(k) != 0 { + ptr := st.rootPtr() + if !st.findGTENode(ptr, k) || ptr.node == nil || ptr.Key().Compare(k) != 0 { return nil, false } return ptr.Value(), true } -func (mt *monoidTree) findGTENode(ptr *monoidTreePointer, x Ordered) bool { +func (st *syncTree) findGTENode(ptr *syncTreePointer, x Ordered) bool { for { switch { case ptr.node == nil: @@ -602,19 +602,19 @@ func (mt *monoidTree) findGTENode(ptr *monoidTreePointer, x Ordered) bool { // All of this subtree is below v, maybe we can have // some luck with the parent node ptr.parent() - mt.findGTENode(ptr, x) + st.findGTENode(ptr, x) case x.Compare(ptr.node.key) >= 0: // We're still below x (or at x, but allowEqual is // false), but given that we checked Max and saw that // this subtree has some keys that are greater than // or equal to x, we can find them on the right if ptr.node.right == nil { - // mn.Max lied to us - panic("BUG: MonoidTreeNode: x > mn.Max but no right branch") + // sn.Max lied to us + panic("BUG: SyncTreeNode: x > sn.Max but no right branch") } // Avoid endless recursion in case of a bug if x.Compare(ptr.node.right.max) > 0 { - panic("BUG: MonoidTreeNode: inconsistent Max on the right branch") + panic("BUG: SyncTreeNode: inconsistent Max on the right branch") } ptr.right() case ptr.node.left == nil || x.Compare(ptr.node.left.max) > 0: @@ -624,26 +624,26 @@ func (mt *monoidTree) findGTENode(ptr *monoidTreePointer, x Ordered) bool { return true default: // Some keys on the left branch are greater or equal - // than x accordingto mn.Left.Max + // than x accordingto sn.Left.Max ptr.left() } } } -func (mt *monoidTree) rangeFingerprint(preceding MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startPtr, endPtr *monoidTreePointer) { - if mt.root == nil { - return mt.m.Identity(), nil, nil +func (st *syncTree) rangeFingerprint(preceding SyncTreePointer, start, end Ordered, stop 
FingerprintPredicate) (fp any, startPtr, endPtr *syncTreePointer) { + if st.root == nil { + return st.m.Identity(), nil, nil } - var ptr *monoidTreePointer + var ptr *syncTreePointer if preceding == nil { - ptr = mt.rootPtr() + ptr = st.rootPtr() } else { - ptr = preceding.(*monoidTreePointer) + ptr = preceding.(*syncTreePointer) } - minPtr := mt.Min().(*monoidTreePointer) - acc := mt.m.Identity() - haveGTE := mt.findGTENode(ptr, start) + minPtr := st.Min().(*syncTreePointer) + acc := st.m.Identity() + haveGTE := st.findGTENode(ptr, start) startPtr = ptr.clone() switch { case start.Compare(end) >= 0: @@ -652,16 +652,16 @@ func (mt *monoidTree) rangeFingerprint(preceding MonoidTreePointer, start, end O // [start, max_element] and [min_element, end) var stopped bool if haveGTE { - acc, stopped = mt.aggregateUntil(ptr, acc, start, UpperBound{}, stop) + acc, stopped = st.aggregateUntil(ptr, acc, start, UpperBound{}, stop) } if !stopped && end.Compare(minPtr.Key()) > 0 { ptr = minPtr.clone() - acc, _ = mt.aggregateUntil(ptr, acc, LowerBound{}, end, stop) + acc, _ = st.aggregateUntil(ptr, acc, LowerBound{}, end, stop) } case haveGTE: // normal range, that is, start < end - acc, _ = mt.aggregateUntil(ptr, mt.m.Identity(), start, end, stop) + acc, _ = st.aggregateUntil(ptr, st.m.Identity(), start, end, stop) } if startPtr.node == nil { @@ -674,11 +674,11 @@ func (mt *monoidTree) rangeFingerprint(preceding MonoidTreePointer, start, end O return acc, startPtr, ptr } -func (mt *monoidTree) RangeFingerprint(ptr MonoidTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode MonoidTreePointer) { - fp, startPtr, endPtr := mt.rangeFingerprint(ptr, start, end, stop) +func (st *syncTree) RangeFingerprint(ptr SyncTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode SyncTreePointer) { + fp, startPtr, endPtr := st.rangeFingerprint(ptr, start, end, stop) switch { case startPtr == nil && endPtr == nil: - // avoid wrapping 
nil in MonoidTreePointer interface + // avoid wrapping nil in SyncTreePointer interface return fp, nil, nil case startPtr == nil || endPtr == nil: panic("BUG: can't have nil node just on one end") @@ -687,19 +687,19 @@ func (mt *monoidTree) RangeFingerprint(ptr MonoidTreePointer, start, end Ordered } } -func (mt *monoidTree) aggregateUntil(ptr *monoidTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { - acc, stopped = mt.aggregateUp(ptr, acc, start, end, stop) +func (st *syncTree) aggregateUntil(ptr *syncTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { + acc, stopped = st.aggregateUp(ptr, acc, start, end, stop) if ptr.node == nil || end.Compare(ptr.node.key) <= 0 || stopped { return acc, stopped } // fmt.Fprintf(os.Stderr, "QQQQQ: from aggregateUp: acc %q; ptr.node %q\n", acc, ptr.node.key) - f := mt.m.Op(acc, mt.m.Fingerprint(ptr.node.key)) + f := st.m.Op(acc, st.m.Fingerprint(ptr.node.key)) if stop.Match(f) { return acc, true } ptr.right() - return mt.aggregateDown(ptr, f, end, stop) + return st.aggregateDown(ptr, f, end, stop) } // aggregateUp ascends from the left (lower) end of the range towards the LCA @@ -716,44 +716,44 @@ func (mt *monoidTree) aggregateUntil(ptr *monoidTreePointer, acc any, start, end // If stop function is passed, we find the node on which it returns true // for the fingerprint accumulated between start and that node, if the target // node is somewhere to the left from the LCA. 
-func (mt *monoidTree) aggregateUp(ptr *monoidTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { +func (st *syncTree) aggregateUp(ptr *syncTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { for { switch { case ptr.node == nil: // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: null node\n") return acc, false case stop.Match(acc): - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop: node %v acc %v\n", mn.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop: node %v acc %v\n", sn.key, acc) ptr.Prev() return acc, true case end.Compare(ptr.node.max) <= 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: LCA: node %v acc %v\n", mn.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: LCA: node %v acc %v\n", sn.key, acc) // This node is a the LCA, the starting point for AggregateDown return acc, false case start.Compare(ptr.node.key) <= 0: // This node is within the target range - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: in-range node %v acc %v\n", mn.key, acc) - f := mt.m.Op(acc, mt.m.Fingerprint(ptr.node.key)) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: in-range node %v acc %v\n", sn.key, acc) + f := st.m.Op(acc, st.m.Fingerprint(ptr.node.key)) if stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop at the own node %v acc %v\n", mn.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop at the own node %v acc %v\n", sn.key, acc) return acc, true } - f1 := mt.m.Op(f, mt.safeFingerprint(ptr.node.right)) + f1 := st.m.Op(f, st.safeFingerprint(ptr.node.right)) if stop.Match(f1) { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree matches node %v acc %v f1 %v\n", mn.key, acc, f1) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree matches node %v acc %v f1 %v\n", sn.key, acc, f1) // The target node is somewhere in the right subtree if ptr.node.right == nil { panic("BUG: nil right child with non-identity fingerprint") } 
ptr.right() - acc := mt.boundedAggregate(ptr, f, stop) + acc := st.boundedAggregate(ptr, f, stop) if ptr.node == nil { panic("BUG: aggregateUp: bad subtree fingerprint on the right branch") } // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree: node %v acc %v\n", node.key, acc) return acc, true } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: no right subtree match: node %v acc %v f1 %v\n", mn.key, acc, f1) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: no right subtree match: node %v acc %v f1 %v\n", sn.key, acc, f1) acc = f1 } } @@ -767,11 +767,11 @@ func (mt *monoidTree) aggregateUp(ptr *monoidTreePointer, acc any, start, end Or // aggregation using their saved fingerprint. // If stop function is passed, we find the node on which it returns true // for the fingerprint accumulated between start and that node -func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { +func (st *syncTree) aggregateDown(ptr *syncTreePointer, acc any, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { for { switch { case ptr.node == nil: - // fmt.Fprintf(os.Stderr, "QQQQQ: mn == nil\n") + // fmt.Fprintf(os.Stderr, "QQQQQ: sn == nil\n") return acc, false case stop.Match(acc): // fmt.Fprintf(os.Stderr, "QQQQQ: stop on node\n") @@ -782,7 +782,7 @@ func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered // We're within the range but there also may be nodes // within the range to the right. 
The left branch is // fully within the range - f := mt.m.Op(acc, mt.safeFingerprint(ptr.node.left)) + f := st.m.Op(acc, st.safeFingerprint(ptr.node.left)) if stop.Match(f) { // fmt.Fprintf(os.Stderr, "QQQQQ: left subtree covers it\n") // The target node is somewhere in the left subtree @@ -790,9 +790,9 @@ func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered panic("BUG: aggregateDown: nil left child with non-identity fingerprint") } ptr.left() - return mt.boundedAggregate(ptr, acc, stop), true + return st.boundedAggregate(ptr, acc, stop), true } - f1 := mt.m.Op(f, mt.m.Fingerprint(ptr.node.key)) + f1 := st.m.Op(f, st.m.Fingerprint(ptr.node.key)) if stop.Match(f1) { // fmt.Fprintf(os.Stderr, "QQQQQ: stop at the node, prev %#v\n", node.prev()) return f, true @@ -804,7 +804,7 @@ func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered case ptr.node.left == nil || end.Compare(ptr.node.left.max) > 0: // fmt.Fprintf(os.Stderr, "QQQQQ: node covers the range\n") // Found the rightmost bounding node - f := mt.m.Op(acc, mt.safeFingerprint(ptr.node.left)) + f := st.m.Op(acc, st.safeFingerprint(ptr.node.left)) if stop.Match(f) { // The target node is somewhere in the left subtree if ptr.node.left == nil { @@ -812,7 +812,7 @@ func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered } // XXXXX fixme ptr.left() - return mt.boundedAggregate(ptr, acc, stop), true + return st.boundedAggregate(ptr, acc, stop), true } return f, false default: @@ -823,16 +823,16 @@ func (mt *monoidTree) aggregateDown(ptr *monoidTreePointer, acc any, end Ordered } } -func (mt *monoidTree) boundedAggregate(ptr *monoidTreePointer, acc any, stop FingerprintPredicate) any { +func (st *syncTree) boundedAggregate(ptr *syncTreePointer, acc any, stop FingerprintPredicate) any { for { - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: node %v, acc %v\n", mn.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: node %v, acc 
%v\n", sn.key, acc) if ptr.node == nil { return acc } // If we don't need to stop, or if the stop point is somewhere after // this subtree, we can just use the pre-calculated subtree fingerprint - if f := mt.m.Op(acc, ptr.node.fingerprint); !stop.Match(f) { + if f := st.m.Op(acc, ptr.node.fingerprint); !stop.Match(f) { return f } @@ -844,54 +844,54 @@ func (mt *monoidTree) boundedAggregate(ptr *monoidTreePointer, acc any, stop Fin if ptr.node.left != nil { // See if we can skip recursion on the left branch - f := mt.m.Op(acc, ptr.node.left.fingerprint) + f := st.m.Op(acc, ptr.node.left.fingerprint) if !stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and no-stop %v, f %v, left fingerprint %v\n", mn.key, f, mn.Left.Fingerprint) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Left non-nil and no-stop %v, f %v, left fingerprint %v\n", sn.key, f, sn.Left.Fingerprint) acc = f } else { // The target node must be contained in the left subtree ptr.left() - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Left non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Left non-nil and stop %v, new node %v, acc %v\n", sn.key, node.key, acc) continue } } - f := mt.m.Op(acc, mt.m.Fingerprint(ptr.node.key)) + f := st.m.Op(acc, st.m.Fingerprint(ptr.node.key)) if stop.Match(f) { return acc } acc = f if ptr.node.right != nil { - f1 := mt.m.Op(f, ptr.node.right.fingerprint) + f1 := st.m.Op(f, ptr.node.right.fingerprint) if !stop.Match(f1) { // The right branch is still below the target fingerprint - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn Right non-nil and no-stop %v, acc %v\n", mn.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Right non-nil and no-stop %v, acc %v\n", sn.key, acc) acc = f1 } else { // The target node must be contained in the right subtree acc = f ptr.right() - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: mn 
Right non-nil and stop %v, new node %v, acc %v\n", mn.key, node.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Right non-nil and stop %v, new node %v, acc %v\n", sn.key, node.key, acc) continue } } - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", mn.key, acc) + // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", sn.key, acc) return acc } } -func (mt *monoidTree) Dump() string { - if mt.root == nil { +func (st *syncTree) Dump() string { + if st.root == nil { return "" } var sb strings.Builder - mt.root.dump(&sb, 0) + st.root.dump(&sb, 0) return sb.String() } // TBD: !!! values and Lookup (via findGTENode) !!! -// TODO: rename MonoidTreeNode to just Node, MonoidTree to SyncTree +// TODO: rename SyncTreeNode to just Node, SyncTree to SyncTree // TODO: use sync.Pool for node alloc // see also: // https://www.akshaydeo.com/blog/2017/12/23/How-did-I-improve-latency-by-700-percent-using-syncPool/ diff --git a/hashsync/monoid_tree_test.go b/hashsync/sync_tree_test.go similarity index 84% rename from hashsync/monoid_tree_test.go rename to hashsync/sync_tree_test.go index bdc9ee155b..144c66d72c 100644 --- a/hashsync/monoid_tree_test.go +++ b/hashsync/sync_tree_test.go @@ -31,40 +31,40 @@ func sampleCountMonoid() Monoid { return CombineMonoids(sampleMonoid{}, CountingMonoid{}) } -func makeStringConcatTree(chars string) MonoidTree { +func makeStringConcatTree(chars string) SyncTree { ids := make([]sampleID, len(chars)) for n, c := range chars { ids[n] = sampleID(c) } - return MonoidTreeFromSlice[sampleID](sampleCountMonoid(), ids) + return SyncTreeFromSlice[sampleID](sampleCountMonoid(), ids) } // dumbAdd inserts the node into the tree without trying to maintain the // red-black properties -func dumbAdd(mt MonoidTree, k Ordered) { - mtree := mt.(*monoidTree) - mtree.root = mtree.insert(mtree.root, k, nil, false, false) +func dumbAdd(st SyncTree, k Ordered) { + stree := st.(*syncTree) + stree.root = 
stree.insert(stree.root, k, nil, false, false) } // makeDumbTree constructs a binary tree by adding the chars one-by-one without // trying to maintain the red-black properties -func makeDumbTree(chars string) MonoidTree { +func makeDumbTree(chars string) SyncTree { if len(chars) == 0 { panic("empty set") } - mt := NewMonoidTree(sampleCountMonoid()) + st := NewSyncTree(sampleCountMonoid()) for _, c := range chars { - dumbAdd(mt, sampleID(c)) + dumbAdd(st, sampleID(c)) } - return mt + return st } -func makeRBTree(chars string) MonoidTree { - mt := NewMonoidTree(sampleCountMonoid()) +func makeRBTree(chars string) SyncTree { + st := NewSyncTree(sampleCountMonoid()) for _, c := range chars { - mt.Add(sampleID(c)) + st.Add(sampleID(c)) } - return mt + return st } func gtePos(all string, item string) int { @@ -126,7 +126,7 @@ func naiveRange(all, x, y string, stopCount int) (fingerprint, startStr, endStr } func TestEmptyTree(t *testing.T) { - tree := NewMonoidTree(sampleCountMonoid()) + tree := NewSyncTree(sampleCountMonoid()) rfp1, startNode, endNode := tree.RangeFingerprint(nil, sampleID("a"), sampleID("a"), nil) require.Nil(t, startNode) require.Nil(t, endNode) @@ -147,7 +147,7 @@ func TestEmptyTree(t *testing.T) { } } -func testMonoidTreeRanges(t *testing.T, tree MonoidTree) { +func testSyncTreeRanges(t *testing.T, tree SyncTree) { all := "abcdefghijklmnopqr" for _, tc := range []struct { all string @@ -234,18 +234,18 @@ func testMonoidTreeRanges(t *testing.T, tree MonoidTree) { } } -func TestMonoidTreeRanges(t *testing.T) { +func TestSyncTreeRanges(t *testing.T) { t.Run("pre-balanced tree", func(t *testing.T) { - testMonoidTreeRanges(t, makeStringConcatTree("abcdefghijklmnopqr")) + testSyncTreeRanges(t, makeStringConcatTree("abcdefghijklmnopqr")) }) t.Run("sequential add", func(t *testing.T) { - testMonoidTreeRanges(t, makeDumbTree("abcdefghijklmnopqr")) + testSyncTreeRanges(t, makeDumbTree("abcdefghijklmnopqr")) }) t.Run("shuffled add", func(t *testing.T) { - 
testMonoidTreeRanges(t, makeDumbTree("lodrnifeqacmbhkgjp")) + testSyncTreeRanges(t, makeDumbTree("lodrnifeqacmbhkgjp")) }) t.Run("red-black add", func(t *testing.T) { - testMonoidTreeRanges(t, makeRBTree("lodrnifeqacmbhkgjp")) + testSyncTreeRanges(t, makeRBTree("lodrnifeqacmbhkgjp")) }) } @@ -270,7 +270,7 @@ func TestAscendingRanges(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { var fps []string - var node MonoidTreePointer + var node SyncTreePointer for n, rng := range tc.ranges { x := sampleID(rng[0]) y := sampleID(rng[1]) @@ -290,64 +290,64 @@ func TestAscendingRanges(t *testing.T) { } } -func verifyBinaryTree(t *testing.T, mn *monoidTreeNode) { - cloned := mn.flags&flagCloned != 0 - if mn.left != nil { +func verifyBinaryTree(t *testing.T, sn *syncTreeNode) { + cloned := sn.flags&flagCloned != 0 + if sn.left != nil { if !cloned { - require.Zero(t, mn.left.flags&flagCloned, "cloned left child of a non-cloned node") + require.Zero(t, sn.left.flags&flagCloned, "cloned left child of a non-cloned node") } - require.Negative(t, mn.left.key.Compare(mn.key)) + require.Negative(t, sn.left.key.Compare(sn.key)) // not a "real" pointer (no parent stack), just to get max - leftMax := &monoidTreePointer{node: mn.left} + leftMax := &syncTreePointer{node: sn.left} leftMax.max() - require.Negative(t, leftMax.Key().Compare(mn.key)) - verifyBinaryTree(t, mn.left) + require.Negative(t, leftMax.Key().Compare(sn.key)) + verifyBinaryTree(t, sn.left) } - if mn.right != nil { + if sn.right != nil { if !cloned { - require.Zero(t, mn.right.flags&flagCloned, "cloned right child of a non-cloned node") + require.Zero(t, sn.right.flags&flagCloned, "cloned right child of a non-cloned node") } - require.Positive(t, mn.right.key.Compare(mn.key)) + require.Positive(t, sn.right.key.Compare(sn.key)) // not a "real" pointer (no parent stack), just to get min - rightMin := &monoidTreePointer{node: mn.right} + rightMin := &syncTreePointer{node: sn.right} rightMin.min() - require.Positive(t, 
rightMin.Key().Compare(mn.key)) - verifyBinaryTree(t, mn.right) + require.Positive(t, rightMin.Key().Compare(sn.key)) + verifyBinaryTree(t, sn.right) } } -func verifyRedBlackNode(t *testing.T, mn *monoidTreeNode, blackDepth int) int { - if mn == nil { +func verifyRedBlackNode(t *testing.T, sn *syncTreeNode, blackDepth int) int { + if sn == nil { return blackDepth + 1 } - if mn.flags&flagBlack == 0 { - if mn.left != nil { - require.Equal(t, flagBlack, mn.left.flags&flagBlack, "left child of a red node is red") + if sn.flags&flagBlack == 0 { + if sn.left != nil { + require.Equal(t, flagBlack, sn.left.flags&flagBlack, "left child of a red node is red") } - if mn.right != nil { - require.Equal(t, flagBlack, mn.right.flags&flagBlack, "right child of a red node is red") + if sn.right != nil { + require.Equal(t, flagBlack, sn.right.flags&flagBlack, "right child of a red node is red") } } else { blackDepth++ } - bdLeft := verifyRedBlackNode(t, mn.left, blackDepth) - bdRight := verifyRedBlackNode(t, mn.right, blackDepth) - require.Equal(t, bdLeft, bdRight, "subtree black depth for node %s", mn.key) + bdLeft := verifyRedBlackNode(t, sn.left, blackDepth) + bdRight := verifyRedBlackNode(t, sn.right, blackDepth) + require.Equal(t, bdLeft, bdRight, "subtree black depth for node %s", sn.key) return bdLeft } -func verifyRedBlack(t *testing.T, mt *monoidTree) { - if mt.root == nil { +func verifyRedBlack(t *testing.T, st *syncTree) { + if st.root == nil { return } - require.Equal(t, flagBlack, mt.root.flags&flagBlack, "root node must be black") - verifyRedBlackNode(t, mt.root, 0) + require.Equal(t, flagBlack, st.root.flags&flagBlack, "root node must be black") + verifyRedBlackNode(t, st.root, 0) } func TestRedBlackTreeInsert(t *testing.T) { for i := 0; i < 1000; i++ { - tree := NewMonoidTree(sampleCountMonoid()) + tree := NewSyncTree(sampleCountMonoid()) items := []byte("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") count := rand.Intn(len(items)) + 1 items = 
items[:count] @@ -359,7 +359,7 @@ func TestRedBlackTreeInsert(t *testing.T) { // items := []byte("0123456789ABCDEFG") // shuffled := []byte("0678DF1CG5A9324BE") - trees := make([]MonoidTree, len(shuffled)) + trees := make([]SyncTree, len(shuffled)) treeDumps := make([]string, len(shuffled)) for i := 0; i < len(shuffled); i++ { trees[i] = tree.Copy() @@ -370,7 +370,7 @@ func TestRedBlackTreeInsert(t *testing.T) { // this shouldn't change anything trees[i-1].Add(sampleID(shuffled[rand.Intn(i-1)])) // cloning should not happen b/c no new nodes are inserted - require.Zero(t, trees[i-1].(*monoidTree).root.flags&flagCloned) + require.Zero(t, trees[i-1].(*syncTree).root.flags&flagCloned) } } @@ -383,8 +383,8 @@ func TestRedBlackTreeInsert(t *testing.T) { // t.Logf("items: %q", string(items)) // t.Logf("shuffled: %q", string(shuffled)) // t.Logf("QQQQQ: tree:\n%s", tree.Dump()) - verifyBinaryTree(t, tree.(*monoidTree).root) - verifyRedBlack(t, tree.(*monoidTree)) + verifyBinaryTree(t, tree.(*syncTree).root) + verifyRedBlack(t, tree.(*syncTree)) for ptr := tree.Min(); ptr.Key() != nil; ptr.Next() { // avoid endless loop due to bugs in the tree impl require.Less(t, n, len(items)*2, "got much more items than needed: %q -- %q", actualItems, shuffled) @@ -402,7 +402,7 @@ func TestRedBlackTreeInsert(t *testing.T) { } } -type makeTestTreeFunc func(chars string) MonoidTree +type makeTestTreeFunc func(chars string) SyncTree func testRandomOrderAndRanges(t *testing.T, mktree makeTestTreeFunc) { all := "abcdefghijklmnopqr" @@ -491,7 +491,7 @@ func TestTreeValues(t *testing.T) { // flagCloned on the root should be cleared after copy // and not set again by Set b/c the value is the same tree.Set(sampleID("d"), 456) // nothing changed - require.Zero(t, tree.(*monoidTree).root.flags&flagCloned) + require.Zero(t, tree.(*syncTree).root.flags&flagCloned) tree1.Set(sampleID("b"), 1234) tree1.Set(sampleID("c"), 222) diff --git a/hashsync/sync_trees_store.go b/hashsync/sync_trees_store.go new 
file mode 100644 index 0000000000..8ff6f20a02 --- /dev/null +++ b/hashsync/sync_trees_store.go @@ -0,0 +1,90 @@ +package hashsync + +type syncTreeIterator struct { + st SyncTree + ptr SyncTreePointer +} + +var _ Iterator = &syncTreeIterator{} + +func (it *syncTreeIterator) Equal(other Iterator) bool { + o := other.(*syncTreeIterator) + if it.st != o.st { + panic("comparing iterators from different SyncTreeStore") + } + return it.ptr.Equal(o.ptr) +} + +func (it *syncTreeIterator) Key() Ordered { + return it.ptr.Key() +} + +func (it *syncTreeIterator) Next() { + it.ptr.Next() + if it.ptr.Key() == nil { + it.ptr = it.st.Min() + } +} + +type SyncTreeStore struct { + st SyncTree +} + +var _ ItemStore = &SyncTreeStore{} + +func NewSyncTreeStore(m Monoid) ItemStore { + return &SyncTreeStore{ + st: NewSyncTree(CombineMonoids(m, CountingMonoid{})), + } +} + +// Add implements ItemStore. +func (sts *SyncTreeStore) Add(k Ordered) { + sts.st.Add(k) +} + +func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { + if ptr == nil { + return nil + } + return &syncTreeIterator{ + st: sts.st, + ptr: ptr, + } +} + +// GetRangeInfo implements ItemStore. +func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x Ordered, y Ordered, count int) RangeInfo { + var stop FingerprintPredicate + var node SyncTreePointer + if preceding != nil { + p := preceding.(*syncTreeIterator) + if p.st != sts.st { + panic("GetRangeInfo: preceding iterator from a wrong SyncTreeStore") + } + node = p.ptr + } + if count >= 0 { + stop = func(fp any) bool { + return CombinedSecond[int](fp) > count + } + } + fp, startPtr, endPtr := sts.st.RangeFingerprint(node, x, y, stop) + cfp := fp.(CombinedFingerprint) + return RangeInfo{ + Fingerprint: cfp.First, + Count: cfp.Second.(int), + Start: sts.iter(startPtr), + End: sts.iter(endPtr), + } +} + +// Min implements ItemStore. +func (sts *SyncTreeStore) Min() Iterator { + return sts.iter(sts.st.Min()) +} + +// Max implements ItemStore. 
+func (sts *SyncTreeStore) Max() Iterator { + return sts.iter(sts.st.Max()) +} diff --git a/hashsync/xorsync_test.go b/hashsync/xorsync_test.go index 67a4c95eae..ba43ef5611 100644 --- a/hashsync/xorsync_test.go +++ b/hashsync/xorsync_test.go @@ -79,7 +79,7 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(syncA, syncB * } sliceA := src[:cfg.numTestHashes-numSpecificB] - storeA := NewMonoidTreeStore(Hash32To12Xor{}) + storeA := NewSyncTreeStore(Hash32To12Xor{}) for _, h := range sliceA { storeA.Add(h) } @@ -88,7 +88,7 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(syncA, syncB * sliceB := append([]types.Hash32(nil), src[:cfg.numTestHashes-numSpecificB-numSpecificA]...) sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) - storeB := NewMonoidTreeStore(Hash32To12Xor{}) + storeB := NewSyncTreeStore(Hash32To12Xor{}) for _, h := range sliceB { storeB.Add(h) } From a8fa3fd74de3b4a60b8210abbb0b46cb2abe9a74 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 12 Jan 2024 07:30:29 +0400 Subject: [PATCH 12/76] hashsync: send the actual objects along with the hashes --- hashsync/handler.go | 99 +++++++++++++++++++++++++++----- hashsync/handler_test.go | 48 ++++++++++------ hashsync/rangesync.go | 28 ++++++--- hashsync/rangesync_test.go | 108 +++++++++++++++++++++++------------ hashsync/sync_trees_store.go | 44 ++++++++++++-- hashsync/wire_types.go | 33 +++++------ hashsync/wire_types_scale.go | 19 +++++- hashsync/xorsync_test.go | 72 +++++++++++++++++------ 8 files changed, 336 insertions(+), 115 deletions(-) diff --git a/hashsync/handler.go b/hashsync/handler.go index 9624f57a47..9340a59d79 100644 --- a/hashsync/handler.go +++ b/hashsync/handler.go @@ -14,6 +14,69 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/server" ) +type sendable interface { + codec.Encodable + Type() MessageType +} + +type decodedItemBatchMessage struct { + ContentKeys []types.Hash32 + ContentValues []any +} + +var _ SyncMessage = 
&decodedItemBatchMessage{} + +func (m *decodedItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } +func (m *decodedItemBatchMessage) X() Ordered { return nil } +func (m *decodedItemBatchMessage) Y() Ordered { return nil } +func (m *decodedItemBatchMessage) Fingerprint() any { return nil } +func (m *decodedItemBatchMessage) Count() int { return 0 } +func (m *decodedItemBatchMessage) Keys() []Ordered { + r := make([]Ordered, len(m.ContentKeys)) + for n, k := range m.ContentKeys { + r[n] = k + } + return r +} +func (m *decodedItemBatchMessage) Values() []any { + r := make([]any, len(m.ContentValues)) + for n, v := range m.ContentValues { + r[n] = v + } + return r +} + +func (m *decodedItemBatchMessage) encode() (*ItemBatchMessage, error) { + var b bytes.Buffer + for _, v := range m.ContentValues { + _, err := codec.EncodeTo(&b, v.(codec.Encodable)) + if err != nil { + return nil, err + } + } + return &ItemBatchMessage{ + ContentKeys: m.ContentKeys, + ContentValues: b.Bytes(), + }, nil +} + +func decodeItemBatchMessage(m *ItemBatchMessage, newValue NewValueFunc) (*decodedItemBatchMessage, error) { + d := &decodedItemBatchMessage{ContentKeys: m.ContentKeys} + b := bytes.NewBuffer(m.ContentValues) + for b.Len() != 0 { + v := newValue().(codec.Decodable) + if _, err := codec.DecodeFrom(b, v); err != nil { + return nil, err + } + d.ContentValues = append(d.ContentValues, v) + } + if len(d.ContentValues) != len(d.ContentKeys) { + return nil, fmt.Errorf("mismatched key / value counts: %d / %d", + len(d.ContentKeys), len(d.ContentValues)) + } + return d, nil +} + type outboundMessage struct { code MessageType // TODO: "mt" msg codec.Encodable @@ -25,6 +88,7 @@ type wireConduit struct { i server.Interactor pendingMsgs []SyncMessage initReqBuf *bytes.Buffer + newValue NewValueFunc // rmmePrint bool } @@ -69,7 +133,11 @@ func (c *wireConduit) receive() (msgs []SyncMessage, err error) { if _, err := codec.DecodeFrom(b, &m); err != nil { return nil, err } - msgs = 
append(msgs, &m) + dm, err := decodeItemBatchMessage(&m, c.newValue) + if err != nil { + return nil, err + } + msgs = append(msgs, dm) case MessageTypeEmptySet: msgs = append(msgs, &EmptySetMessage{}) case MessageTypeEmptyRange: @@ -96,13 +164,13 @@ func (c *wireConduit) receive() (msgs []SyncMessage, err error) { } } -func (c *wireConduit) send(m SyncMessage) error { +func (c *wireConduit) send(m sendable) error { // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: sending %s m %#v\n", m.Type(), m) msg := []byte{byte(m.Type())} // if c.rmmePrint { // fmt.Fprintf(os.Stderr, "QQQQQ: send: %s\n", SyncMessageToString(m)) // } - encoded, err := codec.Encode(m.(codec.Encodable)) + encoded, err := codec.Encode(m) if err != nil { return fmt.Errorf("error encoding %T: %w", m, err) } @@ -170,17 +238,22 @@ func (c *wireConduit) SendRangeContents(x Ordered, y Ordered, count int) error { func (c *wireConduit) SendItems(count, itemChunkSize int, it Iterator) error { for i := 0; i < count; i += itemChunkSize { - var msg ItemBatchMessage + var msg decodedItemBatchMessage n := min(itemChunkSize, count-i) for n > 0 { if it.Key() == nil { panic("fakeConduit.SendItems: went got to the end of the tree") } - msg.Contents = append(msg.Contents, it.Key().(types.Hash32)) + msg.ContentKeys = append(msg.ContentKeys, it.Key().(types.Hash32)) + msg.ContentValues = append(msg.ContentValues, it.Value()) it.Next() n-- } - if err := c.send(&msg); err != nil { + encoded, err := msg.encode() + if err != nil { + return err + } + if err := c.send(encoded); err != nil { return err } } @@ -229,16 +302,18 @@ func makeHandler(rsr *RangeSetReconciler, c *wireConduit, done chan struct{}) se } } -func MakeServerHandler(rsr *RangeSetReconciler) server.InteractiveHandler { +func MakeServerHandler(is ItemStore, opts ...Option) server.InteractiveHandler { return func(ctx context.Context, i server.Interactor) (time.Duration, error) { - var c wireConduit + c := wireConduit{newValue: is.New} + rsr := 
NewRangeSetReconciler(is, opts...) h := makeHandler(rsr, &c, nil) return h(ctx, i) } } -func SyncStore(ctx context.Context, r requester, peer p2p.Peer, rsr *RangeSetReconciler) error { - var c wireConduit +func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, opts ...Option) error { + c := wireConduit{newValue: is.New} + rsr := NewRangeSetReconciler(is, opts...) // c.rmmePrint = true initReq, err := c.withInitialRequest(rsr.Initiate) if err != nil { @@ -261,9 +336,7 @@ func SyncStore(ctx context.Context, r requester, peer p2p.Peer, rsr *RangeSetRec } } -// TODO: HashSyncer object (SyncStore, also server handler, implementing ServerHandler) -// TODO: HashSyncer options instead of itemChunkSize (WithItemChunkSize, WithMaxSendRange) -// TODO: duration +// TODO: request duration // TODO: validate counts // TODO: don't forget about Initiate!!! // TBD: use MessageType instead of byte diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index 7585d90f16..03f347a871 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -185,6 +185,14 @@ func (it *sliceIterator) Key() Ordered { return nil } +func (it *sliceIterator) Value() any { + k := it.Key() + if k == nil { + return nil + } + return mkFakeValue(k.(types.Hash32)) +} + func (it *sliceIterator) Next() { if len(it.s) != 0 { it.s = it.s[1:] @@ -260,7 +268,7 @@ func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { return nil } -func makeTestHandler(t *testing.T, c *wireConduit, done chan struct{}, rounds []fakeRound) server.InteractiveHandler { +func makeTestHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, done chan struct{}, rounds []fakeRound) server.InteractiveHandler { return func(ctx context.Context, i server.Interactor) (time.Duration, error) { defer func() { if done != nil { @@ -268,7 +276,7 @@ func makeTestHandler(t *testing.T, c *wireConduit, done chan struct{}, rounds [] } }() if c == nil { - c = &wireConduit{i: i} + c = 
&wireConduit{i: i, newValue: newValue} } else { c.i = i } @@ -288,7 +296,7 @@ func TestWireConduit(t *testing.T) { hs[n] = types.RandomHash() } fp := types.Hash12(hs[2][:12]) - srvHandler := makeTestHandler(t, nil, nil, []fakeRound{ + srvHandler := makeTestHandler(t, nil, func() any { return new(fakeValue) }, nil, []fakeRound{ { name: "server got 1st request", expectMsgs: []SyncMessage{ @@ -322,11 +330,13 @@ func TestWireConduit(t *testing.T) { { name: "server got 2nd request", expectMsgs: []SyncMessage{ - &ItemBatchMessage{ - Contents: []types.Hash32{hs[9], hs[10]}, + &decodedItemBatchMessage{ + ContentKeys: []types.Hash32{hs[9], hs[10]}, + ContentValues: []any{mkFakeValue(hs[9]), mkFakeValue(hs[10])}, }, - &ItemBatchMessage{ - Contents: []types.Hash32{hs[11]}, + &decodedItemBatchMessage{ + ContentKeys: []types.Hash32{hs[11]}, + ContentValues: []any{mkFakeValue(hs[11])}, }, &EndRoundMessage{}, }, @@ -351,6 +361,7 @@ func TestWireConduit(t *testing.T) { client := newFakeRequester("client", nil, srv) var c wireConduit + c.newValue = func() any { return new(fakeValue) } initReq, err := c.withInitialRequest(func(c Conduit) error { if err := c.SendFingerprint(hs[0], hs[1], fp, 4); err != nil { return err @@ -359,7 +370,7 @@ func TestWireConduit(t *testing.T) { }) require.NoError(t, err) done := make(chan struct{}) - clientHandler := makeTestHandler(t, &c, done, []fakeRound{ + clientHandler := makeTestHandler(t, &c, c.newValue, done, []fakeRound{ { name: "client got 1st response", expectMsgs: []SyncMessage{ @@ -373,11 +384,13 @@ func TestWireConduit(t *testing.T) { RangeY: hs[6], NumItems: 2, }, - &ItemBatchMessage{ - Contents: []types.Hash32{hs[4], hs[5]}, + &decodedItemBatchMessage{ + ContentKeys: []types.Hash32{hs[4], hs[5]}, + ContentValues: []any{mkFakeValue(hs[4]), mkFakeValue(hs[5])}, }, - &ItemBatchMessage{ - Contents: []types.Hash32{hs[7], hs[8]}, + &decodedItemBatchMessage{ + ContentKeys: []types.Hash32{hs[7], hs[8]}, + ContentValues: []any{mkFakeValue(hs[7]), 
mkFakeValue(hs[8])}, }, &EndRoundMessage{}, }, @@ -398,7 +411,8 @@ func TestWireConduit(t *testing.T) { }, }) err = client.InteractiveRequest(context.Background(), "srv", initReq, clientHandler, func(err error) { - require.FailNow(t, "fail handler called", "error: %v", err) + t.Errorf("fail handler called: %v", err) + close(done) }) require.NoError(t, err) <-done @@ -416,8 +430,8 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { maxNumSpecificB: 100, } var client requester - verifyXORSync(t, cfg, func(syncA, syncB *RangeSetReconciler, numSpecific int) { - srvHandler := MakeServerHandler(syncA) + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) { + srvHandler := MakeServerHandler(storeA, opts...) srv, srvPeerID := getRequester("srv", srvHandler) var eg errgroup.Group ctx, cancel := context.WithCancel(context.Background()) @@ -430,7 +444,7 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { }) client, _ = getRequester("client", nil, srv) - err := SyncStore(ctx, client, srvPeerID, syncB) + err := SyncStore(ctx, client, srvPeerID, storeB, opts...) 
require.NoError(t, err) if fr, ok := client.(*fakeRequester); ok { @@ -475,3 +489,5 @@ func TestWireSync(t *testing.T) { }) }) } + +// TODO: test fail handler diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index afdbdbc468..e1087eb917 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -47,7 +47,8 @@ type SyncMessage interface { Y() Ordered Fingerprint() any Count() int - Items() []Ordered + Keys() []Ordered + Values() []any } func SyncMessageToString(m SyncMessage) string { @@ -66,13 +67,16 @@ func SyncMessageToString(m SyncMessage) string { if fp := m.Fingerprint(); fp != nil { sb.WriteString(" FP=" + fp.(fmt.Stringer).String()) } - for _, item := range m.Items() { - sb.WriteString(" item=" + item.(fmt.Stringer).String()) + vals := m.Values() + for n, k := range m.Keys() { + fmt.Fprintf(&sb, " item=[%s:%#v]", k.(fmt.Stringer).String(), vals[n]) } sb.WriteString(">") return sb.String() } +type NewValueFunc func() any + // Conduit handles receiving and sending peer messages type Conduit interface { // NextMessage returns the next SyncMessage, or nil if there @@ -120,10 +124,13 @@ func WithItemChunkSize(n int) Option { type Iterator interface { // Equal returns true if this iterator is equal to another Iterator Equal(other Iterator) bool - // Key returns the key corresponding to iterator + // Key returns the key corresponding to iterator position. It returns + // nil if the ItemStore is empty Key() Ordered + // Value returns the value corresponding to the iterator. It returns nil + // if the ItemStore is empty + Value() any // Next advances the iterator - // TODO: should return bool Next() } @@ -134,8 +141,8 @@ type RangeInfo struct { } type ItemStore interface { - // Add adds a key to the store - Add(k Ordered) + // Add adds a key-value pair to the store + Add(k Ordered, v any) // GetRangeInfo returns RangeInfo for the item range in the tree. 
// If count >= 0, at most count items are returned, and RangeInfo // is returned for the corresponding subrange of the requested range @@ -146,6 +153,8 @@ type ItemStore interface { // Max returns the iterator pointing at the maximum element // in the store. If the store is empty, it returns nil Max() Iterator + // New returns an empty payload value + New() any } type RangeSetReconciler struct { @@ -372,8 +381,9 @@ func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { done = true for _, msg := range msgs { if msg.Type() == MessageTypeItemBatch { - for _, item := range msg.Items() { - rsr.is.Add(item) + vals := msg.Values() + for n, k := range msg.Keys() { + rsr.is.Add(k, vals[n]) } continue } diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 4b5f2a29db..ea8c5593d2 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -11,21 +11,23 @@ import ( ) type rangeMessage struct { - mtype MessageType - x, y Ordered - fp any - count int - items []Ordered + mtype MessageType + x, y Ordered + fp any + count int + keys []Ordered + values []any } +var _ SyncMessage = rangeMessage{} + func (m rangeMessage) Type() MessageType { return m.mtype } func (m rangeMessage) X() Ordered { return m.x } func (m rangeMessage) Y() Ordered { return m.y } func (m rangeMessage) Fingerprint() any { return m.fp } func (m rangeMessage) Count() int { return m.count } -func (m rangeMessage) Items() []Ordered { return m.items } - -var _ SyncMessage = rangeMessage{} +func (m rangeMessage) Keys() []Ordered { return m.keys } +func (m rangeMessage) Values() []any { return m.values } func (m rangeMessage) String() string { return SyncMessageToString(m) @@ -42,7 +44,7 @@ var _ Conduit = &fakeConduit{} func (fc *fakeConduit) numItems() int { n := 0 for _, m := range fc.msgs { - n += len(m.Items()) + n += len(m.Keys()) } return n } @@ -115,7 +117,8 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { if it.Key() == 
nil { panic("fakeConduit.SendItems: went got to the end of the tree") } - msg.items = append(msg.items, it.Key()) + msg.keys = append(msg.keys, it.Key()) + msg.values = append(msg.values, it.Value()) it.Next() n-- } @@ -150,56 +153,70 @@ func (it *dumbStoreIterator) Equal(other Iterator) bool { } func (it *dumbStoreIterator) Key() Ordered { - return it.ds.items[it.n] + return it.ds.keys[it.n] +} + +func (it *dumbStoreIterator) Value() any { + if len(it.ds.keys) == 0 { + return nil + } + return it.ds.m[it.Key().(sampleID)] } func (it *dumbStoreIterator) Next() { - if len(it.ds.items) != 0 { - it.n = (it.n + 1) % len(it.ds.items) + if len(it.ds.keys) != 0 { + it.n = (it.n + 1) % len(it.ds.keys) } } type dumbStore struct { - items []sampleID + keys []sampleID + m map[sampleID]any } var _ ItemStore = &dumbStore{} -func (ds *dumbStore) Add(k Ordered) { +func (ds *dumbStore) Add(k Ordered, v any) { + if ds.m == nil { + ds.m = make(map[sampleID]any) + } id := k.(sampleID) - if len(ds.items) == 0 { - ds.items = []sampleID{id} + if len(ds.keys) == 0 { + ds.keys = []sampleID{id} + ds.m[id] = v return } - p := slices.IndexFunc(ds.items, func(other sampleID) bool { + p := slices.IndexFunc(ds.keys, func(other sampleID) bool { return other >= id }) switch { case p < 0: - ds.items = append(ds.items, id) - case id == ds.items[p]: + ds.keys = append(ds.keys, id) + ds.m[id] = v + case id == ds.keys[p]: // already present default: - ds.items = slices.Insert(ds.items, p, id) + ds.keys = slices.Insert(ds.keys, p, id) + ds.m[id] = v } } func (ds *dumbStore) iter(n int) Iterator { - if n == -1 || n == len(ds.items) { + if n == -1 || n == len(ds.keys) { return nil } return &dumbStoreIterator{ds: ds, n: n} } func (ds *dumbStore) last() sampleID { - if len(ds.items) == 0 { + if len(ds.keys) == 0 { panic("can't get the last element: zero items") } - return ds.items[len(ds.items)-1] + return ds.keys[len(ds.keys)-1] } func (ds *dumbStore) iterFor(s sampleID) Iterator { - n := 
slices.Index(ds.items, s) + n := slices.Index(ds.keys, s) if n == -1 { panic("item not found: " + s) } @@ -229,7 +246,7 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) R } func (ds *dumbStore) Min() Iterator { - if len(ds.items) == 0 { + if len(ds.keys) == 0 { return nil } return &dumbStoreIterator{ @@ -239,15 +256,19 @@ func (ds *dumbStore) Min() Iterator { } func (ds *dumbStore) Max() Iterator { - if len(ds.items) == 0 { + if len(ds.keys) == 0 { return nil } return &dumbStoreIterator{ ds: ds, - n: len(ds.items) - 1, + n: len(ds.keys) - 1, } } +func (it *dumbStore) New() any { + panic("not implemented") +} + type verifiedStoreIterator struct { t *testing.T knownGood Iterator @@ -274,6 +295,13 @@ func (it verifiedStoreIterator) Key() Ordered { return k2 } +func (it verifiedStoreIterator) Value() any { + v1 := it.knownGood.Value() + v2 := it.it.Value() + assert.Equal(it.t, v1, v2, "values") + return v2 +} + func (it verifiedStoreIterator) Next() { it.knownGood.Next() it.it.Next() @@ -296,7 +324,7 @@ func disableReAdd(s ItemStore) { } } -func (vs *verifiedStore) Add(k Ordered) { +func (vs *verifiedStore) Add(k Ordered, v any) { if vs.disableReAdd { _, found := vs.added[k.(sampleID)] require.False(vs.t, found, "hash sent twice: %v", k) @@ -305,8 +333,8 @@ func (vs *verifiedStore) Add(k Ordered) { } vs.added[k.(sampleID)] = struct{}{} } - vs.knownGood.Add(k) - vs.store.Add(k) + vs.knownGood.Add(k, v) + vs.store.Add(k, v) } func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { @@ -358,7 +386,7 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in func (vs *verifiedStore) Min() Iterator { m1 := vs.knownGood.Min() - m2 := vs.knownGood.Min() + m2 := vs.store.Min() if m1 == nil { require.Nil(vs.t, m2, "Min") return nil @@ -375,7 +403,7 @@ func (vs *verifiedStore) Min() Iterator { func (vs *verifiedStore) Max() Iterator { m1 := vs.knownGood.Max() - m2 := 
vs.knownGood.Max() + m2 := vs.store.Max() if m1 == nil { require.Nil(vs.t, m2, "Max") return nil @@ -390,6 +418,13 @@ func (vs *verifiedStore) Max() Iterator { } } +func (vs *verifiedStore) New() any { + v1 := vs.knownGood.New() + v2 := vs.store.New() + require.Equal(vs.t, v1, v2, "New") + return v2 +} + type storeFactory func(t *testing.T) ItemStore func makeDumbStore(t *testing.T) ItemStore { @@ -397,7 +432,10 @@ func makeDumbStore(t *testing.T) ItemStore { } func makeSyncTreeStore(t *testing.T) ItemStore { - return NewSyncTreeStore(sampleMonoid{}) + return NewSyncTreeStore(sampleMonoid{}, nil, func() any { + // newValue func is only called by wireConduit + panic("not implemented") + }) } func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { @@ -411,7 +449,7 @@ func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { func makeStore(t *testing.T, f storeFactory, items string) ItemStore { s := f(t) for _, c := range items { - s.Add(sampleID(c)) + s.Add(sampleID(c), "") } return s } diff --git a/hashsync/sync_trees_store.go b/hashsync/sync_trees_store.go index 8ff6f20a02..a6f7d2e2e1 100644 --- a/hashsync/sync_trees_store.go +++ b/hashsync/sync_trees_store.go @@ -1,8 +1,24 @@ package hashsync +type ValueHandler interface { + Load(k Ordered, treeValue any) (v any) + Store(k Ordered, v any) (treeValue any) +} + +type defaultValueHandler struct{} + +func (vh defaultValueHandler) Load(k Ordered, treeValue any) (v any) { + return treeValue +} + +func (vh defaultValueHandler) Store(k Ordered, v any) (treeValue any) { + return v +} + type syncTreeIterator struct { st SyncTree ptr SyncTreePointer + vh ValueHandler } var _ Iterator = &syncTreeIterator{} @@ -19,6 +35,10 @@ func (it *syncTreeIterator) Key() Ordered { return it.ptr.Key() } +func (it *syncTreeIterator) Value() any { + return it.vh.Load(it.ptr.Key(), it.ptr.Value()) +} + func (it *syncTreeIterator) Next() { it.ptr.Next() if it.ptr.Key() == nil { @@ -27,20 +47,28 @@ func (it *syncTreeIterator) Next() { } type 
SyncTreeStore struct { - st SyncTree + st SyncTree + vh ValueHandler + newValue NewValueFunc } var _ ItemStore = &SyncTreeStore{} -func NewSyncTreeStore(m Monoid) ItemStore { +func NewSyncTreeStore(m Monoid, vh ValueHandler, newValue NewValueFunc) ItemStore { + if vh == nil { + vh = defaultValueHandler{} + } return &SyncTreeStore{ - st: NewSyncTree(CombineMonoids(m, CountingMonoid{})), + st: NewSyncTree(CombineMonoids(m, CountingMonoid{})), + vh: vh, + newValue: newValue, } } // Add implements ItemStore. -func (sts *SyncTreeStore) Add(k Ordered) { - sts.st.Add(k) +func (sts *SyncTreeStore) Add(k Ordered, v any) { + treeValue := sts.vh.Store(k, v) + sts.st.Set(k, treeValue) } func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { @@ -50,6 +78,7 @@ func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { return &syncTreeIterator{ st: sts.st, ptr: ptr, + vh: sts.vh, } } @@ -88,3 +117,8 @@ func (sts *SyncTreeStore) Min() Iterator { func (sts *SyncTreeStore) Max() Iterator { return sts.iter(sts.st.Max()) } + +// New implements ItemStore. +func (sts *SyncTreeStore) New() any { + return sts.newValue() +} diff --git a/hashsync/wire_types.go b/hashsync/wire_types.go index c5b6f5a5a3..39a21c152d 100644 --- a/hashsync/wire_types.go +++ b/hashsync/wire_types.go @@ -12,7 +12,8 @@ func (*Marker) X() Ordered { return nil } func (*Marker) Y() Ordered { return nil } func (*Marker) Fingerprint() any { return nil } func (*Marker) Count() int { return 0 } -func (*Marker) Items() []Ordered { return nil } +func (*Marker) Keys() []Ordered { return nil } +func (*Marker) Values() []any { return nil } // DoneMessage is a SyncMessage that denotes the end of the synchronization. // The peer should stop any further processing after receiving this message. 
@@ -50,7 +51,8 @@ func (m *EmptyRangeMessage) X() Ordered { return m.RangeX } func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY } func (m *EmptyRangeMessage) Fingerprint() any { return nil } func (m *EmptyRangeMessage) Count() int { return 0 } -func (m *EmptyRangeMessage) Items() []Ordered { return nil } +func (m *EmptyRangeMessage) Keys() []Ordered { return nil } +func (m *EmptyRangeMessage) Values() []any { return nil } // FingerprintMessage contains range fingerprint for comparison against the // peer's fingerprint of the range with the same bounds [RangeX, RangeY) @@ -67,7 +69,8 @@ func (m *FingerprintMessage) X() Ordered { return m.RangeX } func (m *FingerprintMessage) Y() Ordered { return m.RangeY } func (m *FingerprintMessage) Fingerprint() any { return m.RangeFingerprint } func (m *FingerprintMessage) Count() int { return int(m.NumItems) } -func (m *FingerprintMessage) Items() []Ordered { return nil } +func (m *FingerprintMessage) Keys() []Ordered { return nil } +func (m *FingerprintMessage) Values() []any { return nil } // RangeContentsMessage denotes a range for which the set of items has been sent. // The peer needs to send back any items it has in the same range bounded @@ -84,26 +87,18 @@ func (m *RangeContentsMessage) X() Ordered { return m.RangeX } func (m *RangeContentsMessage) Y() Ordered { return m.RangeY } func (m *RangeContentsMessage) Fingerprint() any { return nil } func (m *RangeContentsMessage) Count() int { return int(m.NumItems) } -func (m *RangeContentsMessage) Items() []Ordered { return nil } +func (m *RangeContentsMessage) Keys() []Ordered { return nil } +func (m *RangeContentsMessage) Values() []any { return nil } -// ItemBatchMessage denotes a batch of items to be added to the peer's set +// ItemBatchMessage denotes a batch of items to be added to the peer's set. 
+// ItemBatchMessage doesn't implement SyncMessage interface by itself +// and needs to be wrapped in TypedItemBatchMessage[T] that implements +// SyncMessage by providing the proper Values() method type ItemBatchMessage struct { - Contents []types.Hash32 `scale:"max=1024"` + ContentKeys []types.Hash32 `scale:"max=1024"` + ContentValues []byte `scale:"max=1024"` } -var _ SyncMessage = &ItemBatchMessage{} - func (m *ItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } -func (m *ItemBatchMessage) X() Ordered { return nil } -func (m *ItemBatchMessage) Y() Ordered { return nil } -func (m *ItemBatchMessage) Fingerprint() any { return nil } -func (m *ItemBatchMessage) Count() int { return 0 } -func (m *ItemBatchMessage) Items() []Ordered { - r := make([]Ordered, len(m.Contents)) - for n, item := range m.Contents { - r[n] = item - } - return r -} // TODO: don't do scalegen for empty types diff --git a/hashsync/wire_types_scale.go b/hashsync/wire_types_scale.go index bd51c04cdc..91c863f951 100644 --- a/hashsync/wire_types_scale.go +++ b/hashsync/wire_types_scale.go @@ -236,7 +236,14 @@ func (t *RangeContentsMessage) DecodeScale(dec *scale.Decoder) (total int, err e func (t *ItemBatchMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeStructSliceWithLimit(enc, t.Contents, 1024) + n, err := scale.EncodeStructSliceWithLimit(enc, t.ContentKeys, 1024) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteSliceWithLimit(enc, t.ContentValues, 1024) if err != nil { return total, err } @@ -252,7 +259,15 @@ func (t *ItemBatchMessage) DecodeScale(dec *scale.Decoder) (total int, err error return total, err } total += n - t.Contents = field + t.ContentKeys = field + } + { + field, n, err := scale.DecodeByteSliceWithLimit(dec, 1024) + if err != nil { + return total, err + } + total += n + t.ContentValues = field } return total, nil } diff --git a/hashsync/xorsync_test.go 
b/hashsync/xorsync_test.go index ba43ef5611..58278c3735 100644 --- a/hashsync/xorsync_test.go +++ b/hashsync/xorsync_test.go @@ -5,6 +5,7 @@ import ( "slices" "testing" + "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -29,14 +30,22 @@ func TestHash32To12Xor(t *testing.T) { require.Equal(t, m.Op(m.Op(fp1, fp2), fp3), m.Op(fp1, m.Op(fp2, fp3))) } -func collectStoreItems[T Ordered](is ItemStore) (r []T) { +type pair[K any, V any] struct { + k K + v V +} + +func collectStoreItems[K Ordered, V any](is ItemStore) (r []pair[K, V]) { it := is.Min() if it == nil { return nil } endAt := is.Min() for { - r = append(r, it.Key().(T)) + r = append(r, pair[K, V]{ + k: it.Key().(K), + v: it.Value().(V), + }) it.Next() if it.Equal(endAt) { return r @@ -50,11 +59,11 @@ type catchTransferTwice struct { added map[types.Hash32]bool } -func (s *catchTransferTwice) Add(k Ordered) { +func (s *catchTransferTwice) Add(k Ordered, v any) { h := k.(types.Hash32) _, found := s.added[h] assert.False(s.t, found, "hash sent twice") - s.ItemStore.Add(k) + s.ItemStore.Add(k, v) if s.added == nil { s.added = make(map[types.Hash32]bool) } @@ -70,7 +79,31 @@ type xorSyncTestConfig struct { maxNumSpecificB int } -func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(syncA, syncB *RangeSetReconciler, numSpecific int)) { +type fakeValue struct { + v string +} + +var _ scale.Decodable = &fakeValue{} +var _ scale.Encodable = &fakeValue{} + +func mkFakeValue(h types.Hash32) *fakeValue { + return &fakeValue{v: h.String()} +} + +func (fv *fakeValue) DecodeScale(dec *scale.Decoder) (total int, err error) { + s, total, err := scale.DecodeString(dec) + fv.v = s + return total, err +} + +func (fv *fakeValue) EncodeScale(enc *scale.Encoder) (total int, err error) { + return scale.EncodeString(enc, fv.v) +} + +func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB 
ItemStore, numSpecific int, opts []Option)) { + opts := []Option{ + WithMaxSendRange(cfg.maxSendRange), + } numSpecificA := rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA numSpecificB := rand.Intn(cfg.maxNumSpecificB+1-cfg.minNumSpecificB) + cfg.minNumSpecificB src := make([]types.Hash32, cfg.numTestHashes) @@ -79,32 +112,37 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(syncA, syncB * } sliceA := src[:cfg.numTestHashes-numSpecificB] - storeA := NewSyncTreeStore(Hash32To12Xor{}) + storeA := NewSyncTreeStore(Hash32To12Xor{}, nil, func() any { return new(fakeValue) }) for _, h := range sliceA { - storeA.Add(h) + storeA.Add(h, mkFakeValue(h)) } storeA = &catchTransferTwice{t: t, ItemStore: storeA} - syncA := NewRangeSetReconciler(storeA, WithMaxSendRange(cfg.maxSendRange)) sliceB := append([]types.Hash32(nil), src[:cfg.numTestHashes-numSpecificB-numSpecificA]...) sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) - storeB := NewSyncTreeStore(Hash32To12Xor{}) + storeB := NewSyncTreeStore(Hash32To12Xor{}, nil, func() any { return new(fakeValue) }) for _, h := range sliceB { - storeB.Add(h) + storeB.Add(h, mkFakeValue(h)) } storeB = &catchTransferTwice{t: t, ItemStore: storeB} - syncB := NewRangeSetReconciler(storeB, WithMaxSendRange(cfg.maxSendRange)) slices.SortFunc(src, func(a, b types.Hash32) int { return a.Compare(b) }) - sync(syncA, syncB, numSpecificA+numSpecificB) + sync(storeA, storeB, numSpecificA+numSpecificB, opts) - itemsA := collectStoreItems[types.Hash32](storeA) - itemsB := collectStoreItems[types.Hash32](storeB) + itemsA := collectStoreItems[types.Hash32, *fakeValue](storeA) + itemsB := collectStoreItems[types.Hash32, *fakeValue](storeB) require.Equal(t, itemsA, itemsB) - require.Equal(t, src, itemsA) + srcPairs := make([]pair[types.Hash32, *fakeValue], len(src)) + for n, h := range src { + srcPairs[n] = pair[types.Hash32, *fakeValue]{ + k: h, + v: mkFakeValue(h), + } + } + require.Equal(t, 
srcPairs, itemsA) } func TestBigSyncHash32(t *testing.T) { @@ -116,7 +154,9 @@ func TestBigSyncHash32(t *testing.T) { minNumSpecificB: 4, maxNumSpecificB: 100, } - verifyXORSync(t, cfg, func(syncA, syncB *RangeSetReconciler, numSpecific int) { + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) { + syncA := NewRangeSetReconciler(storeA, opts...) + syncB := NewRangeSetReconciler(storeB, opts...) nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) itemCoef := float64(nItems) / float64(numSpecific) t.Logf("numSpecific: %d, nRounds: %d, nMsg: %d, nItems: %d, itemCoef: %.2f", From 57767f917e8303d61ce34b796433389a8dd46c6f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 13 Jan 2024 08:07:59 +0400 Subject: [PATCH 13/76] hashsync: support bounded sync --- hashsync/rangesync.go | 19 +++++--- hashsync/rangesync_test.go | 90 ++++++++++++++++++++++++++++---------- 2 files changed, 80 insertions(+), 29 deletions(-) diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index e1087eb917..1bf4b721a9 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -315,25 +315,32 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg func (rsr *RangeSetReconciler) Initiate(c Conduit) error { it := rsr.is.Min() - if it == nil { + var x Ordered + if it != nil { + x = it.Key() + } + return rsr.InitiateBounded(c, x, x) +} + +func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { + if x == nil { if err := c.SendEmptySet(); err != nil { return err } } else { - min := it.Key() - info := rsr.is.GetRangeInfo(nil, min, min, -1) + info := rsr.is.GetRangeInfo(nil, x, y, -1) switch { case info.Count == 0: panic("empty full min-min range") case info.Count < rsr.maxSendRange: - if err := c.SendRangeContents(min, min, info.Count); err != nil { + if err := c.SendRangeContents(x, y, info.Count); err != nil { return err } - if err := c.SendItems(info.Count, rsr.itemChunkSize, it); err != nil { + 
if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { return err } default: - if err := c.SendFingerprint(min, min, info.Fingerprint, info.Count); err != nil { + if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { return err } } diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index ea8c5593d2..6663363159 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -506,23 +506,33 @@ func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...an func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { fc := &fakeConduit{t: t} - syncA.Initiate(fc) + require.NoError(t, syncA.Initiate(fc)) + return doRunSync(fc, syncA, syncB, maxRounds) +} + +func runBoundedSync(t *testing.T, syncA, syncB *RangeSetReconciler, x, y Ordered, maxRounds int) (nRounds, nMsg, nItems int) { + fc := &fakeConduit{t: t} + require.NoError(t, syncA.InitiateBounded(fc, x, y)) + return doRunSync(fc, syncA, syncB, maxRounds) +} + +func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { var i int done := false - // dumpRangeMessages(t, fc.resp.msgs, "A %q -> B %q (init):", storeItemStr(syncA.is), storeItemStr(syncB.is)) - // dumpRangeMessages(t, fc.resp.msgs, "A -> B (init):") + // dumpRangeMessages(fc.t, fc.resp.msgs, "A %q -> B %q (init):", storeItemStr(syncA.is), storeItemStr(syncB.is)) + // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B (init):") for i = 0; !done; i++ { if i == maxRounds { - require.FailNow(t, "too many rounds", "didn't reconcile in %d rounds", i) + require.FailNow(fc.t, "too many rounds", "didn't reconcile in %d rounds", i) } fc = fc.resp nMsg += len(fc.msgs) nItems += fc.numItems() var err error done, err = syncB.Process(fc) - require.NoError(t, err) - // dumpRangeMessages(t, fc.resp.msgs, "B %q -> A %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) - // 
dumpRangeMessages(t, fc.resp.msgs, "B -> A:") + require.NoError(fc.t, err) + // dumpRangeMessages(fc.t, fc.resp.msgs, "B %q -> A %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) + // dumpRangeMessages(fc.t, fc.resp.msgs, "B -> A:") if done { break } @@ -530,9 +540,9 @@ func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRo nMsg += len(fc.msgs) nItems += fc.numItems() done, err = syncA.Process(fc) - require.NoError(t, err) - // dumpRangeMessages(t, fc.msgs, "A %q --> B %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) - // dumpRangeMessages(t, fc.resp.msgs, "A -> B:") + require.NoError(fc.t, err) + // dumpRangeMessages(fc.t, fc.msgs, "A %q --> B %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) + // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B:") } return i + 1, nMsg, nItems } @@ -541,50 +551,78 @@ func testRangeSync(t *testing.T, storeFactory storeFactory) { for _, tc := range []struct { name string a, b string - final string + finalA string + finalB string maxRounds [4]int + x, y string }{ { name: "empty sets", a: "", b: "", - final: "", + finalA: "", + finalB: "", maxRounds: [4]int{1, 1, 1, 1}, }, { name: "empty to non-empty", a: "", b: "abcd", - final: "abcd", - maxRounds: [4]int{1, 1, 1, 1}, + finalA: "abcd", + finalB: "abcd", + maxRounds: [4]int{2, 2, 2, 2}, }, { name: "non-empty to empty", a: "abcd", b: "", - final: "abcd", + finalA: "abcd", + finalB: "abcd", maxRounds: [4]int{2, 2, 2, 2}, }, { name: "non-intersecting sets", a: "ab", b: "cd", - final: "abcd", + finalA: "abcd", + finalB: "abcd", maxRounds: [4]int{3, 2, 2, 2}, }, { name: "intersecting sets", a: "acdefghijklmn", b: "bcdopqr", - final: "abcdefghijklmnopqr", - maxRounds: [4]int{4, 4, 4, 3}, + finalA: "abcdefghijklmnopqr", + finalB: "abcdefghijklmnopqr", + maxRounds: [4]int{4, 4, 3, 3}, + }, + { + name: "bounded reconciliation", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "abcdefghijklmn", + finalB: "abcdefgopqr", + maxRounds: [4]int{3, 3, 2, 2}, + x: 
"a", + y: "h", + }, + { + name: "bounded reconciliation with rollover", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "acdefghijklmnopqr", + finalB: "bcdhijklmnopqr", + maxRounds: [4]int{4, 3, 3, 2}, + x: "h", + y: "a", }, { name: "sync against 1-element set", a: "bcd", b: "a", - final: "abcd", - maxRounds: [4]int{3, 2, 2, 1}, + finalA: "abcd", + finalB: "abcd", + maxRounds: [4]int{2, 2, 2, 2}, }, } { t.Run(tc.name, func(t *testing.T) { @@ -601,11 +639,17 @@ func testRangeSync(t *testing.T, storeFactory storeFactory) { WithMaxSendRange(maxSendRange), WithItemChunkSize(3)) - nRounds, _, _ := runSync(t, syncA, syncB, tc.maxRounds[n]) + var nRounds int + if tc.x == "" { + nRounds, _, _ = runSync(t, syncA, syncB, tc.maxRounds[n]) + } else { + nRounds, _, _ = runBoundedSync(t, syncA, syncB, + sampleID(tc.x), sampleID(tc.y), tc.maxRounds[n]) + } t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) - require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) - require.Equal(t, tc.final, storeItemStr(storeA)) + require.Equal(t, tc.finalA, storeItemStr(storeA)) + require.Equal(t, tc.finalB, storeItemStr(storeB)) } }) } From e3d3cd60bb61b0d02705ffcffade7d24321de2d1 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 13 Jan 2024 12:22:39 +0400 Subject: [PATCH 14/76] hashsync: implement sync probes --- hashsync/handler.go | 105 ++++++++- hashsync/handler_test.go | 163 +++++++++---- hashsync/rangesync.go | 58 ++++- hashsync/rangesync_test.go | 435 +++++++++++++++++++++-------------- hashsync/wire_types.go | 25 ++ hashsync/wire_types_scale.go | 38 +++ hashsync/xorsync_test.go | 27 +-- 7 files changed, 604 insertions(+), 247 deletions(-) diff --git a/hashsync/handler.go b/hashsync/handler.go index 9340a59d79..779cb83ac2 100644 --- a/hashsync/handler.go +++ b/hashsync/handler.go @@ -77,11 +77,6 @@ func decodeItemBatchMessage(m *ItemBatchMessage, newValue NewValueFunc) (*decode return d, nil } -type outboundMessage struct { - code MessageType // TODO: 
"mt" - msg codec.Encodable -} - type conduitState int type wireConduit struct { @@ -158,6 +153,12 @@ func (c *wireConduit) receive() (msgs []SyncMessage, err error) { return nil, err } msgs = append(msgs, &m) + case MessageTypeQuery: + var m QueryMessage + if _, err := codec.DecodeFrom(b, &m); err != nil { + return nil, err + } + msgs = append(msgs, &m) default: return nil, fmt.Errorf("invalid message code %02x", code) } @@ -211,7 +212,7 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { return msgs[0], nil } -func (c *wireConduit) SendFingerprint(x Ordered, y Ordered, fingerprint any, count int) error { +func (c *wireConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { return c.send(&FingerprintMessage{ RangeX: x.(types.Hash32), RangeY: y.(types.Hash32), @@ -224,11 +225,11 @@ func (c *wireConduit) SendEmptySet() error { return c.send(&EmptySetMessage{}) } -func (c *wireConduit) SendEmptyRange(x Ordered, y Ordered) error { +func (c *wireConduit) SendEmptyRange(x, y Ordered) error { return c.send(&EmptyRangeMessage{RangeX: x.(types.Hash32), RangeY: y.(types.Hash32)}) } -func (c *wireConduit) SendRangeContents(x Ordered, y Ordered, count int) error { +func (c *wireConduit) SendRangeContents(x, y Ordered, count int) error { return c.send(&RangeContentsMessage{ RangeX: x.(types.Hash32), RangeY: y.(types.Hash32), @@ -268,6 +269,17 @@ func (c *wireConduit) SendDone() error { return c.send(&DoneMessage{}) } +func (c *wireConduit) SendQuery(x, y Ordered) error { + if x == nil && y == nil { + return c.send(&QueryMessage{}) + } else if x == nil || y == nil { + panic("BUG: SendQuery: bad range: just one of the bounds is nil") + } + xh := x.(types.Hash32) + yh := y.(types.Hash32) + return c.send(&QueryMessage{RangeX: &xh, RangeY: &yh}) +} + func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, error) { c.initReqBuf = new(bytes.Buffer) defer func() { c.initReqBuf = nil }() @@ -311,11 +323,29 @@ func MakeServerHandler(is 
ItemStore, opts ...Option) server.InteractiveHandler { } } +func BoundedSyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y types.Hash32, opts ...Option) error { + return syncStore(ctx, r, peer, is, &x, &y, opts) +} + func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, opts ...Option) error { + return syncStore(ctx, r, peer, is, nil, nil, opts) +} + +func syncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, opts []Option) error { c := wireConduit{newValue: is.New} rsr := NewRangeSetReconciler(is, opts...) // c.rmmePrint = true - initReq, err := c.withInitialRequest(rsr.Initiate) + var ( + initReq []byte + err error + ) + if x == nil { + initReq, err = c.withInitialRequest(rsr.Initiate) + } else { + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.InitiateBounded(c, *x, *y) + }) + } if err != nil { return err } @@ -330,12 +360,69 @@ func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, op } select { case <-ctx.Done(): + <-done return ctx.Err() case <-done: return reqErr } } +func Probe(ctx context.Context, r requester, peer p2p.Peer, opts ...Option) (fp any, count int, err error) { + return boundedProbe(ctx, r, peer, nil, nil, opts) +} + +func BoundedProbe(ctx context.Context, r requester, peer p2p.Peer, x, y types.Hash32, opts ...Option) (fp any, count int, err error) { + return boundedProbe(ctx, r, peer, &x, &y, opts) +} + +func boundedProbe(ctx context.Context, r requester, peer p2p.Peer, x, y *types.Hash32, opts []Option) (fp any, count int, err error) { + c := wireConduit{ + newValue: func() any { return nil }, // not used + } + rsr := NewRangeSetReconciler(nil, opts...) 
+ // c.rmmePrint = true + var initReq []byte + if x == nil { + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.InitiateProbe(c) + }) + } else { + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.InitiateBoundedProbe(c, *x, *y) + }) + } + if err != nil { + return nil, 0, err + } + done := make(chan struct{}, 2) + h := func(ctx context.Context, i server.Interactor) (time.Duration, error) { + defer func() { + done <- struct{}{} + }() + c.i = i + var err error + fp, count, err = rsr.HandleProbeResponse(&c) + return 0, err + } + var reqErr error + if err = r.InteractiveRequest(ctx, peer, initReq, h, func(err error) { + reqErr = err + done <- struct{}{} + }); err != nil { + return nil, 0, err + } + select { + case <-ctx.Done(): + <-done + return nil, 0, ctx.Err() + case <-done: + if reqErr != nil { + return nil, 0, reqErr + } + return fp, count, nil + } +} + // TODO: request duration // TODO: validate counts // TODO: don't forget about Initiate!!! diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index 03f347a871..6cf55f37b4 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -420,6 +420,61 @@ func TestWireConduit(t *testing.T) { type getRequesterFunc func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) +func withClientServer( + storeA, storeB ItemStore, + getRequester getRequesterFunc, + opts []Option, + toCall func(ctx context.Context, client requester, srvPeerID p2p.Peer), +) { + srvHandler := MakeServerHandler(storeA, opts...) 
+ srv, srvPeerID := getRequester("srv", srvHandler) + var eg errgroup.Group + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + eg.Wait() + }() + eg.Go(func() error { + return srv.Run(ctx) + }) + + client, _ := getRequester("client", nil, srv) + toCall(ctx, client, srvPeerID) +} + +func fakeRequesterGetter(t *testing.T) getRequesterFunc { + return func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + pid := p2p.Peer(name) + return newFakeRequester(pid, handler, peers...), pid + } +} + +func p2pRequesterGetter(t *testing.T) getRequesterFunc { + mesh, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + proto := "itest" + opts := []server.Opt{ + server.WithTimeout(10 * time.Second), + server.WithLog(logtest.New(t)), + } + return func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + if len(peers) == 0 { + return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() + } + s := server.New(mesh.Hosts()[1], proto, handler, opts...) + // TODO: this 'Eventually' is somewhat misplaced + require.Eventually(t, func() bool { + for _, h := range mesh.Hosts()[0:] { + if len(h.Mux().Protocols()) == 0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + return s, mesh.Hosts()[1].ID() + } +} + func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { cfg := xorSyncTestConfig{ maxSendRange: 1, @@ -430,64 +485,76 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { maxNumSpecificB: 100, } var client requester - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) { - srvHandler := MakeServerHandler(storeA, opts...) 
- srv, srvPeerID := getRequester("srv", srvHandler) - var eg errgroup.Group - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - cancel() - eg.Wait() - }() - eg.Go(func() error { - return srv.Run(ctx) - }) - - client, _ = getRequester("client", nil, srv) - err := SyncStore(ctx, client, srvPeerID, storeB, opts...) - require.NoError(t, err) - - if fr, ok := client.(*fakeRequester); ok { - t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", - numSpecific, fr.bytesSent, fr.bytesReceived) - } + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { + withClientServer( + storeA, storeB, getRequester, opts, + func(ctx context.Context, client requester, srvPeerID p2p.Peer) { + err := SyncStore(ctx, client, srvPeerID, storeB, opts...) + require.NoError(t, err) + + if fr, ok := client.(*fakeRequester); ok { + t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", + numSpecific, fr.bytesSent, fr.bytesReceived) + } + }) + return true }) return client } func TestWireSync(t *testing.T) { t.Run("fake requester", func(t *testing.T) { - testWireSync(t, func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { - pid := p2p.Peer(name) - return newFakeRequester(pid, handler, peers...), pid - }) + testWireSync(t, fakeRequesterGetter(t)) + }) + t.Run("p2p", func(t *testing.T) { + testWireSync(t, p2pRequesterGetter(t)) }) +} +func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { + cfg := xorSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 32, + minNumSpecificA: 4, + maxNumSpecificA: 4, + minNumSpecificB: 4, + maxNumSpecificB: 4, + } + var client requester + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { + withClientServer( + storeA, storeB, getRequester, opts, + func(ctx context.Context, client requester, srvPeerID p2p.Peer) { + minA := storeA.Min().Key() + infoA := storeA.GetRangeInfo(nil, minA, minA, -1) + 
fpA, countA, err := Probe(ctx, client, srvPeerID, opts...) + require.NoError(t, err) + require.Equal(t, infoA.Fingerprint, fpA) + require.Equal(t, infoA.Count, countA) + + minA = storeA.Min().Key() + partInfoA := storeA.GetRangeInfo(nil, minA, minA, infoA.Count/2) + x := partInfoA.Start.Key().(types.Hash32) + y := partInfoA.End.Key().(types.Hash32) + // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) + fpA, countA, err = BoundedProbe(ctx, client, srvPeerID, x, y, opts...) + require.NoError(t, err) + require.Equal(t, partInfoA.Fingerprint, fpA) + require.Equal(t, partInfoA.Count, countA) + }) + return false + }) + return client +} + +func TestWireProbe(t *testing.T) { + t.Run("fake requester", func(t *testing.T) { + testWireProbe(t, fakeRequesterGetter(t)) + }) t.Run("p2p", func(t *testing.T) { - mesh, err := mocknet.FullMeshConnected(2) - require.NoError(t, err) - proto := "itest" - opts := []server.Opt{ - server.WithTimeout(10 * time.Second), - server.WithLog(logtest.New(t)), - } - testWireSync(t, func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { - if len(peers) == 0 { - return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() - } - s := server.New(mesh.Hosts()[1], proto, handler, opts...) 
- // TODO: this 'Eventually' is somewhat misplaced - require.Eventually(t, func() bool { - for _, h := range mesh.Hosts()[0:] { - if len(h.Mux().Protocols()) == 0 { - return false - } - } - return true - }, time.Second, 10*time.Millisecond) - return s, mesh.Hosts()[1].ID() - }) + testWireProbe(t, p2pRequesterGetter(t)) }) } +// TODO: test bounded sync // TODO: test fail handler diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 1bf4b721a9..e9e7483ec5 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -22,6 +22,7 @@ const ( MessageTypeFingerprint MessageTypeRangeContents MessageTypeItemBatch + MessageTypeQuery ) var messageTypes = []string{ @@ -104,6 +105,10 @@ type Conduit interface { SendEndRound() error // SendDone sends a message that notifies the peer that sync is finished SendDone() error + // SendQuery sends a message requesting fingerprint and count of the + // whole range or part of the range. The response will never contain any + // actual data items + SendQuery(x, y Ordered) error } type Option func(r *RangeSetReconciler) @@ -233,7 +238,7 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg x := msg.X() y := msg.Y() done = true - if msg.Type() == MessageTypeEmptySet { + if msg.Type() == MessageTypeEmptySet || (msg.Type() == MessageTypeQuery && x == nil && y == nil) { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) it := rsr.is.Min() @@ -263,6 +268,11 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg return nil, false, err } } + case msg.Type() == MessageTypeQuery: + if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { + return nil, false, err + } + return nil, true, nil case msg.Type() != MessageTypeFingerprint: return nil, false, fmt.Errorf("unexpected message type %s", msg.Type()) case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): @@ -372,6 +382,52 @@ func (rsr *RangeSetReconciler) getMessages(c 
Conduit) (msgs []SyncMessage, done } } +func (rsr *RangeSetReconciler) InitiateProbe(c Conduit) error { + return rsr.InitiateBoundedProbe(c, nil, nil) +} + +func (rsr *RangeSetReconciler) InitiateBoundedProbe(c Conduit, x, y Ordered) error { + if err := c.SendQuery(x, y); err != nil { + return err + } + if err := c.SendEndRound(); err != nil { + return err + } + return nil +} + +func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit) (fp any, count int, err error) { + gotRange := false + for { + msg, err := c.NextMessage() + switch { + case err != nil: + return nil, 0, err + case msg == nil: + return nil, 0, errors.New("no end round marker") + default: + switch mt := msg.Type(); mt { + case MessageTypeEndRound: + return nil, 0, errors.New("non-final round in response to a probe") + case MessageTypeDone: + // the peer is not expecting any new messages + return fp, count, nil + case MessageTypeFingerprint: + fp = msg.Fingerprint() + count = msg.Count() + fallthrough + case MessageTypeEmptySet, MessageTypeEmptyRange: + if gotRange { + return nil, 0, errors.New("single range message expected") + } + gotRange = true + default: + return nil, 0, fmt.Errorf("unexpected message type: %v", msg.Type()) + } + } + } +} + func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { var msgs []SyncMessage // All of the messages need to be received before processing diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 6663363159..373759ff7a 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -36,11 +36,16 @@ func (m rangeMessage) String() string { type fakeConduit struct { t *testing.T msgs []rangeMessage - resp *fakeConduit + resp []rangeMessage } var _ Conduit = &fakeConduit{} +func (fc *fakeConduit) gotoResponse() { + fc.msgs = fc.resp + fc.resp = nil +} + func (fc *fakeConduit) numItems() int { n := 0 for _, m := range fc.msgs { @@ -59,14 +64,7 @@ func (fc *fakeConduit) NextMessage() (SyncMessage, error) { 
return nil, nil } -func (fc *fakeConduit) ensureResp() { - if fc.resp == nil { - fc.resp = &fakeConduit{t: fc.t} - } -} - func (fc *fakeConduit) sendMsg(mtype MessageType, x, y Ordered, fingerprint any, count int) { - fc.ensureResp() msg := rangeMessage{ mtype: mtype, x: x, @@ -74,7 +72,7 @@ func (fc *fakeConduit) sendMsg(mtype MessageType, x, y Ordered, fingerprint any, fp: fingerprint, count: count, } - fc.resp.msgs = append(fc.resp.msgs, msg) + fc.resp = append(fc.resp, msg) } func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { @@ -109,7 +107,6 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { require.Positive(fc.t, count) require.NotZero(fc.t, count) require.NotNil(fc.t, it) - fc.ensureResp() for i := 0; i < count; i += itemChunkSize { msg := rangeMessage{mtype: MessageTypeItemBatch} n := min(itemChunkSize, count-i) @@ -122,7 +119,7 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { it.Next() n-- } - fc.resp.msgs = append(fc.resp.msgs, msg) + fc.resp = append(fc.resp, msg) } return nil } @@ -137,6 +134,11 @@ func (fc *fakeConduit) SendDone() error { return nil } +func (fc *fakeConduit) SendQuery(x, y Ordered) error { + fc.sendMsg(MessageTypeQuery, x, y, nil, 0) + return nil +} + type dumbStoreIterator struct { ds *dumbStore n int @@ -518,203 +520,284 @@ func runBoundedSync(t *testing.T, syncA, syncB *RangeSetReconciler, x, y Ordered func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { var i int - done := false + aDone, bDone := false, false // dumpRangeMessages(fc.t, fc.resp.msgs, "A %q -> B %q (init):", storeItemStr(syncA.is), storeItemStr(syncB.is)) // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B (init):") - for i = 0; !done; i++ { + for i = 0; ; i++ { if i == maxRounds { require.FailNow(fc.t, "too many rounds", "didn't reconcile in %d rounds", i) } - fc = fc.resp + fc.gotoResponse() nMsg += 
len(fc.msgs) nItems += fc.numItems() var err error - done, err = syncB.Process(fc) + bDone, err = syncB.Process(fc) require.NoError(fc.t, err) + // a party should never send anything in response to the "done" message + require.False(fc.t, aDone && !bDone, "A is done but B after that is not") // dumpRangeMessages(fc.t, fc.resp.msgs, "B %q -> A %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) // dumpRangeMessages(fc.t, fc.resp.msgs, "B -> A:") - if done { + if aDone && bDone { + require.Empty(fc.t, fc.resp, "got messages from B in response to done msg from A") break } - fc = fc.resp + fc.gotoResponse() nMsg += len(fc.msgs) nItems += fc.numItems() - done, err = syncA.Process(fc) + aDone, err = syncA.Process(fc) require.NoError(fc.t, err) // dumpRangeMessages(fc.t, fc.msgs, "A %q --> B %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B:") + require.False(fc.t, bDone && !aDone, "B is done but A after that is not") + if aDone && bDone { + require.Empty(fc.t, fc.resp, "got messages from A in response to done msg from B") + break + } } return i + 1, nMsg, nItems } -func testRangeSync(t *testing.T, storeFactory storeFactory) { - for _, tc := range []struct { - name string - a, b string - finalA string - finalB string - maxRounds [4]int - x, y string - }{ - { - name: "empty sets", - a: "", - b: "", - finalA: "", - finalB: "", - maxRounds: [4]int{1, 1, 1, 1}, - }, - { - name: "empty to non-empty", - a: "", - b: "abcd", - finalA: "abcd", - finalB: "abcd", - maxRounds: [4]int{2, 2, 2, 2}, - }, - { - name: "non-empty to empty", - a: "abcd", - b: "", - finalA: "abcd", - finalB: "abcd", - maxRounds: [4]int{2, 2, 2, 2}, - }, - { - name: "non-intersecting sets", - a: "ab", - b: "cd", - finalA: "abcd", - finalB: "abcd", - maxRounds: [4]int{3, 2, 2, 2}, - }, - { - name: "intersecting sets", - a: "acdefghijklmn", - b: "bcdopqr", - finalA: "abcdefghijklmnopqr", - finalB: "abcdefghijklmnopqr", - maxRounds: [4]int{4, 4, 3, 3}, - }, 
- { - name: "bounded reconciliation", - a: "acdefghijklmn", - b: "bcdopqr", - finalA: "abcdefghijklmn", - finalB: "abcdefgopqr", - maxRounds: [4]int{3, 3, 2, 2}, - x: "a", - y: "h", - }, - { - name: "bounded reconciliation with rollover", - a: "acdefghijklmn", - b: "bcdopqr", - finalA: "acdefghijklmnopqr", - finalB: "bcdhijklmnopqr", - maxRounds: [4]int{4, 3, 3, 2}, - x: "h", - y: "a", - }, - { - name: "sync against 1-element set", - a: "bcd", - b: "a", - finalA: "abcd", - finalB: "abcd", - maxRounds: [4]int{2, 2, 2, 2}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - for n, maxSendRange := range []int{1, 2, 3, 4} { - t.Logf("maxSendRange: %d", maxSendRange) - storeA := makeStore(t, storeFactory, tc.a) - disableReAdd(storeA) - syncA := NewRangeSetReconciler(storeA, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - storeB := makeStore(t, storeFactory, tc.b) - disableReAdd(storeB) - syncB := NewRangeSetReconciler(storeB, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - - var nRounds int - if tc.x == "" { - nRounds, _, _ = runSync(t, syncA, syncB, tc.maxRounds[n]) - } else { - nRounds, _, _ = runBoundedSync(t, syncA, syncB, - sampleID(tc.x), sampleID(tc.y), tc.maxRounds[n]) - } - t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) +func runProbe(t *testing.T, from, to *RangeSetReconciler) (fp any, count int) { + fc := &fakeConduit{t: t} + require.NoError(t, from.InitiateProbe(fc)) + return doRunProbe(fc, from, to) +} - require.Equal(t, tc.finalA, storeItemStr(storeA)) - require.Equal(t, tc.finalB, storeItemStr(storeB)) - } - }) - } +func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y Ordered) (fp any, count int) { + fc := &fakeConduit{t: t} + require.NoError(t, from.InitiateBoundedProbe(fc, x, y)) + return doRunProbe(fc, from, to) } -func TestRangeSync(t *testing.T) { - forTestStores(t, testRangeSync) +func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler) (fp any, count int) { + 
require.NotEmpty(fc.t, fc.resp, "empty initial round") + fc.gotoResponse() + done, err := to.Process(fc) + require.True(fc.t, done) + require.NoError(fc.t, err) + fc.gotoResponse() + fp, count, err = from.HandleProbeResponse(fc) + require.NoError(fc.t, err) + require.Nil(fc.t, fc.resp, "got messages from Probe in response to done msg") + return fp, count } -func testRandomSync(t *testing.T, storeFactory storeFactory) { - var bytesA, bytesB []byte - defer func() { - if t.Failed() { - t.Logf("Random sync failed: %q <-> %q", bytesA, bytesB) - } - }() - for i := 0; i < 1000; i++ { - var chars []byte - for c := byte(33); c < 127; c++ { - chars = append(chars, c) +func TestRangeSync(t *testing.T) { + forTestStores(t, func(t *testing.T, storeFactory storeFactory) { + for _, tc := range []struct { + name string + a, b string + finalA string + finalB string + x, y string + countA int + countB int + fpA any + fpB any + maxRounds [4]int + }{ + { + name: "empty sets", + a: "", + b: "", + finalA: "", + finalB: "", + countA: 0, + countB: 0, + fpA: nil, + fpB: nil, + maxRounds: [4]int{1, 1, 1, 1}, + }, + { + name: "empty to non-empty", + a: "", + b: "abcd", + finalA: "abcd", + finalB: "abcd", + countA: 0, + countB: 4, + fpA: nil, + fpB: "abcd", + maxRounds: [4]int{2, 2, 2, 2}, + }, + { + name: "non-empty to empty", + a: "abcd", + b: "", + finalA: "abcd", + finalB: "abcd", + countA: 4, + countB: 0, + fpA: "abcd", + fpB: nil, + maxRounds: [4]int{2, 2, 2, 2}, + }, + { + name: "non-intersecting sets", + a: "ab", + b: "cd", + finalA: "abcd", + finalB: "abcd", + countA: 2, + countB: 2, + fpA: "ab", + fpB: "cd", + maxRounds: [4]int{3, 2, 2, 2}, + }, + { + name: "intersecting sets", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "abcdefghijklmnopqr", + finalB: "abcdefghijklmnopqr", + countA: 13, + countB: 7, + fpA: "acdefghijklmn", + fpB: "bcdopqr", + maxRounds: [4]int{4, 4, 3, 3}, + }, + { + name: "bounded reconciliation", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: 
"abcdefghijklmn", + finalB: "abcdefgopqr", + x: "a", + y: "h", + countA: 6, + countB: 3, + fpA: "acdefg", + fpB: "bcd", + maxRounds: [4]int{3, 3, 2, 2}, + }, + { + name: "bounded reconciliation with rollover", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "acdefghijklmnopqr", + finalB: "bcdhijklmnopqr", + x: "h", + y: "a", + countA: 7, + countB: 4, + fpA: "hijklmn", + fpB: "opqr", + maxRounds: [4]int{4, 3, 3, 2}, + }, + { + name: "sync against 1-element set", + a: "bcd", + b: "a", + finalA: "abcd", + finalB: "abcd", + countA: 3, + countB: 1, + fpA: "bcd", + fpB: "a", + maxRounds: [4]int{2, 2, 2, 2}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + for n, maxSendRange := range []int{1, 2, 3, 4} { + t.Logf("maxSendRange: %d", maxSendRange) + storeA := makeStore(t, storeFactory, tc.a) + disableReAdd(storeA) + syncA := NewRangeSetReconciler(storeA, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + storeB := makeStore(t, storeFactory, tc.b) + disableReAdd(storeB) + syncB := NewRangeSetReconciler(storeB, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + + var ( + countA, countB, nRounds int + fpA, fpB any + ) + if tc.x == "" { + fpA, countA = runProbe(t, syncB, syncA) + fpB, countB = runProbe(t, syncA, syncB) + nRounds, _, _ = runSync(t, syncA, syncB, tc.maxRounds[n]) + } else { + x := sampleID(tc.x) + y := sampleID(tc.y) + fpA, countA = runBoundedProbe(t, syncB, syncA, x, y) + fpB, countB = runBoundedProbe(t, syncA, syncB, x, y) + nRounds, _, _ = runBoundedSync(t, syncA, syncB, x, y, tc.maxRounds[n]) + } + t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) + + require.Equal(t, tc.countA, countA, "countA") + require.Equal(t, tc.countB, countB, "countB") + require.Equal(t, tc.fpA, fpA, "fpA") + require.Equal(t, tc.fpB, fpB, "fpB") + require.Equal(t, tc.finalA, storeItemStr(storeA), "finalA") + require.Equal(t, tc.finalB, storeItemStr(storeB), "finalB") + } + }) } + }) +} - bytesA = append([]byte(nil), chars...) 
- rand.Shuffle(len(bytesA), func(i, j int) { - bytesA[i], bytesA[j] = bytesA[j], bytesA[i] - }) - bytesA = bytesA[:rand.Intn(len(bytesA))] - storeA := makeStore(t, storeFactory, string(bytesA)) +func TestRandomSync(t *testing.T) { + forTestStores(t, func(t *testing.T, storeFactory storeFactory) { + var bytesA, bytesB []byte + defer func() { + if t.Failed() { + t.Logf("Random sync failed: %q <-> %q", bytesA, bytesB) + } + }() + for i := 0; i < 1000; i++ { + var chars []byte + for c := byte(33); c < 127; c++ { + chars = append(chars, c) + } - bytesB = append([]byte(nil), chars...) - rand.Shuffle(len(bytesB), func(i, j int) { - bytesB[i], bytesB[j] = bytesB[j], bytesB[i] - }) - bytesB = bytesB[:rand.Intn(len(bytesB))] - storeB := makeStore(t, storeFactory, string(bytesB)) + bytesA = append([]byte(nil), chars...) + rand.Shuffle(len(bytesA), func(i, j int) { + bytesA[i], bytesA[j] = bytesA[j], bytesA[i] + }) + bytesA = bytesA[:rand.Intn(len(bytesA))] + storeA := makeStore(t, storeFactory, string(bytesA)) + + bytesB = append([]byte(nil), chars...) + rand.Shuffle(len(bytesB), func(i, j int) { + bytesB[i], bytesB[j] = bytesB[j], bytesB[i] + }) + bytesB = bytesB[:rand.Intn(len(bytesB))] + storeB := makeStore(t, storeFactory, string(bytesB)) + + keySet := make(map[byte]struct{}) + for _, c := range append(bytesA, bytesB...) { + keySet[byte(c)] = struct{}{} + } - keySet := make(map[byte]struct{}) - for _, c := range append(bytesA, bytesB...) { - keySet[byte(c)] = struct{}{} + expectedSet := maps.Keys(keySet) + slices.Sort(expectedSet) + + maxSendRange := rand.Intn(16) + 1 + syncA := NewRangeSetReconciler(storeA, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + syncB := NewRangeSetReconciler(storeB, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + + runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! 
+ // t.Logf("maxSendRange %d a %d b %d n %d", maxSendRange, len(bytesA), len(bytesB), n) + require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) + require.Equal(t, string(expectedSet), storeItemStr(storeA), + "expected set for %q<->%q", bytesA, bytesB) } - - expectedSet := maps.Keys(keySet) - slices.Sort(expectedSet) - - maxSendRange := rand.Intn(16) + 1 - syncA := NewRangeSetReconciler(storeA, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - syncB := NewRangeSetReconciler(storeB, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - - runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! - // t.Logf("maxSendRange %d a %d b %d n %d", maxSendRange, len(bytesA), len(bytesB), n) - require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) - require.Equal(t, string(expectedSet), storeItemStr(storeA), - "expected set for %q<->%q", bytesA, bytesB) - } -} - -func TestRandomSync(t *testing.T) { - forTestStores(t, testRandomSync) + }) } -// TBD: include initiate round!!! +// TBD: make sure that requests with MessageTypeDone are never +// answered!!! // TBD: use logger for verbose logging (messages) -// TBD: in fakeConduit -- check item count against the iterator in SendItems / SendItemsOnly!! -// TBD: record interaction using golden master in testRangeSync, together with N of rounds / msgs / items and don't check max rounds +// TBD: in fakeConduit -- check item count against the iterator in +// SendItems / SendItemsOnly!! 
+// TBD: record interaction using golden master in testRangeSync, for +// both probe and sync, together with N of rounds / msgs / items +// and don't check max rounds diff --git a/hashsync/wire_types.go b/hashsync/wire_types.go index 39a21c152d..839525bb48 100644 --- a/hashsync/wire_types.go +++ b/hashsync/wire_types.go @@ -101,4 +101,29 @@ type ItemBatchMessage struct { func (m *ItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } +// QueryMessage requests bounded range fingerprint and count from the peer +type QueryMessage struct { + RangeX, RangeY *types.Hash32 +} + +var _ SyncMessage = &QueryMessage{} + +func (m *QueryMessage) Type() MessageType { return MessageTypeQuery } +func (m *QueryMessage) X() Ordered { + if m.RangeX == nil { + return nil + } + return *m.RangeX +} +func (m *QueryMessage) Y() Ordered { + if m.RangeY == nil { + return nil + } + return *m.RangeY +} +func (m *QueryMessage) Fingerprint() any { return nil } +func (m *QueryMessage) Count() int { return 0 } +func (m *QueryMessage) Keys() []Ordered { return nil } +func (m *QueryMessage) Values() []any { return nil } + // TODO: don't do scalegen for empty types diff --git a/hashsync/wire_types_scale.go b/hashsync/wire_types_scale.go index 91c863f951..fddb0ae1ae 100644 --- a/hashsync/wire_types_scale.go +++ b/hashsync/wire_types_scale.go @@ -271,3 +271,41 @@ func (t *ItemBatchMessage) DecodeScale(dec *scale.Decoder) (total int, err error } return total, nil } + +func (t *QueryMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeOption(enc, t.RangeX) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeOption(enc, t.RangeY) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *QueryMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + field, n, err := scale.DecodeOption[types.Hash32](dec) + if err != nil { + return total, err + } + total += n + t.RangeX 
= field + } + { + field, n, err := scale.DecodeOption[types.Hash32](dec) + if err != nil { + return total, err + } + total += n + t.RangeY = field + } + return total, nil +} diff --git a/hashsync/xorsync_test.go b/hashsync/xorsync_test.go index 58278c3735..9c9922925e 100644 --- a/hashsync/xorsync_test.go +++ b/hashsync/xorsync_test.go @@ -100,7 +100,7 @@ func (fv *fakeValue) EncodeScale(enc *scale.Encoder) (total int, err error) { return scale.EncodeString(enc, fv.v) } -func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []Option)) { +func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool) { opts := []Option{ WithMaxSendRange(cfg.maxSendRange), } @@ -130,19 +130,19 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB return a.Compare(b) }) - sync(storeA, storeB, numSpecificA+numSpecificB, opts) - - itemsA := collectStoreItems[types.Hash32, *fakeValue](storeA) - itemsB := collectStoreItems[types.Hash32, *fakeValue](storeB) - require.Equal(t, itemsA, itemsB) - srcPairs := make([]pair[types.Hash32, *fakeValue], len(src)) - for n, h := range src { - srcPairs[n] = pair[types.Hash32, *fakeValue]{ - k: h, - v: mkFakeValue(h), + if sync(storeA, storeB, numSpecificA+numSpecificB, opts) { + itemsA := collectStoreItems[types.Hash32, *fakeValue](storeA) + itemsB := collectStoreItems[types.Hash32, *fakeValue](storeB) + require.Equal(t, itemsA, itemsB) + srcPairs := make([]pair[types.Hash32, *fakeValue], len(src)) + for n, h := range src { + srcPairs[n] = pair[types.Hash32, *fakeValue]{ + k: h, + v: mkFakeValue(h), + } } + require.Equal(t, srcPairs, itemsA) } - require.Equal(t, srcPairs, itemsA) } func TestBigSyncHash32(t *testing.T) { @@ -154,12 +154,13 @@ func TestBigSyncHash32(t *testing.T) { minNumSpecificB: 4, maxNumSpecificB: 100, } - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts 
[]Option) { + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { syncA := NewRangeSetReconciler(storeA, opts...) syncB := NewRangeSetReconciler(storeB, opts...) nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) itemCoef := float64(nItems) / float64(numSpecific) t.Logf("numSpecific: %d, nRounds: %d, nMsg: %d, nItems: %d, itemCoef: %.2f", numSpecific, nRounds, nMsg, nItems, itemCoef) + return true }) } From f8a021f26fc32fc7e6cae56afd963ef7d45ce81d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sun, 24 Mar 2024 08:34:20 +0400 Subject: [PATCH 15/76] hashsync: convert from chunked streams to normal streams --- hashsync/handler.go | 271 ++++++++++++++------------------------- hashsync/handler_test.go | 165 ++++++++---------------- hashsync/interface.go | 2 +- hashsync/rangesync.go | 9 +- 4 files changed, 151 insertions(+), 296 deletions(-) diff --git a/hashsync/handler.go b/hashsync/handler.go index 779cb83ac2..5e51632bf4 100644 --- a/hashsync/handler.go +++ b/hashsync/handler.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "time" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" @@ -38,6 +37,7 @@ func (m *decodedItemBatchMessage) Keys() []Ordered { } return r } + func (m *decodedItemBatchMessage) Values() []any { r := make([]any, len(m.ContentValues)) for n, v := range m.ContentValues { @@ -80,136 +80,86 @@ func decodeItemBatchMessage(m *ItemBatchMessage, newValue NewValueFunc) (*decode type conduitState int type wireConduit struct { - i server.Interactor - pendingMsgs []SyncMessage - initReqBuf *bytes.Buffer - newValue NewValueFunc + stream io.ReadWriter + initReqBuf *bytes.Buffer + newValue NewValueFunc // rmmePrint bool } var _ Conduit = &wireConduit{} -func (c *wireConduit) reset() { - c.pendingMsgs = nil -} - -// receive receives a single frame from the Interactor and decodes one -// or more SyncMessages from it. 
The frames contain just one message -// except for the initial frame which may contain multiple messages -// b/c of the way Server handles the initial request -func (c *wireConduit) receive() (msgs []SyncMessage, err error) { - data, err := c.i.Receive() - if err != nil { - return nil, err - } - if len(data) == 0 { - return nil, errors.New("zero length sync message") +// NextMessage implements Conduit. +func (c *wireConduit) NextMessage() (SyncMessage, error) { + var b [1]byte + if _, err := io.ReadFull(c.stream, b[:]); err != nil { + if !errors.Is(err, io.EOF) { + return nil, err + } + return nil, nil } - b := bytes.NewBuffer(data) - for { - code, err := b.ReadByte() + mtype := MessageType(b[0]) + // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: receive message type %s\n", mtype) + switch mtype { + case MessageTypeDone: + return &DoneMessage{}, nil + case MessageTypeEndRound: + return &EndRoundMessage{}, nil + case MessageTypeItemBatch: + var m ItemBatchMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + dm, err := decodeItemBatchMessage(&m, c.newValue) if err != nil { - if !errors.Is(err, io.EOF) { - // this shouldn't really happen - return nil, err - } - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: decoded msgs: %#v\n", msgs) - return msgs, nil + return nil, err } - mtype := MessageType(code) - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: receive message type %s\n", mtype) - switch mtype { - case MessageTypeDone: - msgs = append(msgs, &DoneMessage{}) - case MessageTypeEndRound: - msgs = append(msgs, &EndRoundMessage{}) - case MessageTypeItemBatch: - var m ItemBatchMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - dm, err := decodeItemBatchMessage(&m, c.newValue) - if err != nil { - return nil, err - } - msgs = append(msgs, dm) - case MessageTypeEmptySet: - msgs = append(msgs, &EmptySetMessage{}) - case MessageTypeEmptyRange: - var m EmptyRangeMessage - if _, err := codec.DecodeFrom(b, 
&m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - case MessageTypeFingerprint: - var m FingerprintMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - case MessageTypeRangeContents: - var m RangeContentsMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - case MessageTypeQuery: - var m QueryMessage - if _, err := codec.DecodeFrom(b, &m); err != nil { - return nil, err - } - msgs = append(msgs, &m) - default: - return nil, fmt.Errorf("invalid message code %02x", code) + return dm, nil + case MessageTypeEmptySet: + return &EmptySetMessage{}, nil + case MessageTypeEmptyRange: + var m EmptyRangeMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err } + return &m, nil + case MessageTypeFingerprint: + var m FingerprintMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeRangeContents: + var m RangeContentsMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeQuery: + var m QueryMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + default: + return nil, fmt.Errorf("invalid message code %02x", b[0]) } } func (c *wireConduit) send(m sendable) error { - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: sending %s m %#v\n", m.Type(), m) - msg := []byte{byte(m.Type())} - // if c.rmmePrint { - // fmt.Fprintf(os.Stderr, "QQQQQ: send: %s\n", SyncMessageToString(m)) - // } - encoded, err := codec.Encode(m) - if err != nil { - return fmt.Errorf("error encoding %T: %w", m, err) - } - msg = append(msg, encoded...) 
+ var stream io.Writer if c.initReqBuf != nil { - c.initReqBuf.Write(msg) + stream = c.initReqBuf + } else if c.stream == nil { + panic("BUG: wireConduit: no stream") } else { - if err := c.i.Send(msg); err != nil { - return err - } - } - return nil -} - -// NextMessage implements Conduit. -func (c *wireConduit) NextMessage() (SyncMessage, error) { - if len(c.pendingMsgs) != 0 { - m := c.pendingMsgs[0] - c.pendingMsgs = c.pendingMsgs[1:] - // if c.rmmePrint { - // fmt.Fprintf(os.Stderr, "QQQQQ: recv: %s\n", SyncMessageToString(m)) - // } - return m, nil + stream = c.stream } - - msgs, err := c.receive() - if err != nil { - return nil, err - } - if len(msgs) == 0 { - return nil, nil + b := []byte{byte(m.Type())} + if _, err := stream.Write(b); err != nil { + return err } - - c.pendingMsgs = msgs[1:] - // if c.rmmePrint { - // fmt.Fprintf(os.Stderr, "QQQQQ: recv: %s\n", SyncMessageToString(msgs[0])) - // } - return msgs[0], nil + _, err := codec.EncodeTo(stream, m) + return err } func (c *wireConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { @@ -289,37 +239,32 @@ func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, er return c.initReqBuf.Bytes(), nil } -func makeHandler(rsr *RangeSetReconciler, c *wireConduit, done chan struct{}) server.InteractiveHandler { - return func(ctx context.Context, i server.Interactor) (time.Duration, error) { - defer func() { - if done != nil { - close(done) - } - }() - c.i = i - for { - c.reset() - // Process() will receive all items and messages from the peer - syncDone, err := rsr.Process(c) - if err != nil { - // do not close done if we're returning an - // error, as the channel will be closed in the - // error handler func - done = nil - return 0, err - } else if syncDone { - return 0, nil - } +func (c *wireConduit) handleStream(stream io.ReadWriter, rsr *RangeSetReconciler) error { + c.stream = stream + for { + // Process() will receive all items and messages from the peer + syncDone, 
err := rsr.Process(c) + if err != nil { + return err + } else if syncDone { + return nil } } } -func MakeServerHandler(is ItemStore, opts ...Option) server.InteractiveHandler { - return func(ctx context.Context, i server.Interactor) (time.Duration, error) { +func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { + return func(ctx context.Context, req []byte, stream io.ReadWriter) error { c := wireConduit{newValue: is.New} rsr := NewRangeSetReconciler(is, opts...) - h := makeHandler(rsr, &c, nil) - return h(ctx, i) + s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(req), stream), + Writer: stream, + } + return c.handleStream(s, rsr) } } @@ -349,22 +294,9 @@ func syncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, if err != nil { return err } - done := make(chan struct{}, 1) - h := makeHandler(rsr, &c, done) - var reqErr error - if err = r.InteractiveRequest(ctx, peer, initReq, h, func(err error) { - reqErr = err - close(done) - }); err != nil { - return err - } - select { - case <-ctx.Done(): - <-done - return ctx.Err() - case <-done: - return reqErr - } + return r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + return c.handleStream(stream, rsr) + }) } func Probe(ctx context.Context, r requester, peer p2p.Peer, opts ...Option) (fp any, count int, err error) { @@ -394,33 +326,16 @@ func boundedProbe(ctx context.Context, r requester, peer p2p.Peer, x, y *types.H if err != nil { return nil, 0, err } - done := make(chan struct{}, 2) - h := func(ctx context.Context, i server.Interactor) (time.Duration, error) { - defer func() { - done <- struct{}{} - }() - c.i = i + err = r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + c.stream = stream var err error fp, count, err = rsr.HandleProbeResponse(&c) - return 0, err - } - var reqErr error - if err = 
r.InteractiveRequest(ctx, peer, initReq, h, func(err error) { - reqErr = err - done <- struct{}{} - }); err != nil { + return err + }) + if err != nil { return nil, 0, err } - select { - case <-ctx.Done(): - <-done - return nil, 0, ctx.Err() - case <-done: - if reqErr != nil { - return nil, 0, reqErr - } - return fp, count, nil - } + return fp, count, nil } // TODO: request duration diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index 6cf55f37b4..2d98903425 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -1,10 +1,11 @@ package hashsync import ( + "bytes" "context" "fmt" + "io" "slices" - "sync/atomic" "testing" "time" @@ -18,67 +19,14 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/server" ) -type fakeMessage struct { - data []byte - error string -} - -type fakeInteractor struct { - fr *fakeRequester - ctx context.Context - sendCh chan fakeMessage - recvCh chan fakeMessage -} - -func (i *fakeInteractor) Send(data []byte) error { - // fmt.Fprintf(os.Stderr, "%p: send %q\n", i, data) - select { - case i.sendCh <- fakeMessage{data: data}: - atomic.AddUint32(&i.fr.bytesSent, uint32(len(data))) - return nil - case <-i.ctx.Done(): - return i.ctx.Err() - } -} - -func (i *fakeInteractor) SendError(err error) error { - // fmt.Fprintf(os.Stderr, "%p: send error %q\n", i, err) - select { - case i.sendCh <- fakeMessage{error: err.Error()}: - atomic.AddUint32(&i.fr.bytesSent, uint32(len(err.Error()))) - return nil - case <-i.ctx.Done(): - return i.ctx.Err() - } -} - -func (i *fakeInteractor) Receive() ([]byte, error) { - // fmt.Fprintf(os.Stderr, "%p: receive\n", i) - var m fakeMessage - select { - case m = <-i.recvCh: - case <-i.ctx.Done(): - return nil, i.ctx.Err() - } - // fmt.Fprintf(os.Stderr, "%p: received %#v\n", i, m) - if m.error != "" { - atomic.AddUint32(&i.fr.bytesReceived, uint32(len(m.error))) - return nil, fmt.Errorf("%w: %s", server.RemoteError, m.error) - } - atomic.AddUint32(&i.fr.bytesReceived, 
uint32(len(m.data))) - return m.data, nil -} - type incomingRequest struct { - sendCh chan fakeMessage - recvCh chan fakeMessage + initialRequest []byte + stream io.ReadWriter } -var _ server.Interactor = &fakeInteractor{} - type fakeRequester struct { id p2p.Peer - handler server.ServerHandler + handler server.StreamHandler peers map[p2p.Peer]*fakeRequester reqCh chan incomingRequest bytesSent uint32 @@ -87,7 +35,7 @@ type fakeRequester struct { var _ requester = &fakeRequester{} -func newFakeRequester(id p2p.Peer, handler server.ServerHandler, peers ...requester) *fakeRequester { +func newFakeRequester(id p2p.Peer, handler server.StreamHandler, peers ...requester) *fakeRequester { fr := &fakeRequester{ id: id, handler: handler, @@ -112,13 +60,9 @@ func (fr *fakeRequester) Run(ctx context.Context) error { return nil case req = <-fr.reqCh: } - i := &fakeInteractor{ - fr: fr, - ctx: ctx, - sendCh: req.sendCh, - recvCh: req.recvCh, + if err := fr.handler(ctx, req.initialRequest, req.stream); err != nil { + panic("handler error: " + err.Error()) } - fr.handler.Handle(ctx, i) } } @@ -126,45 +70,41 @@ func (fr *fakeRequester) request( ctx context.Context, pid p2p.Peer, initialRequest []byte, - handler server.InteractiveHandler, + callback server.StreamRequestCallback, ) error { p, found := fr.peers[pid] if !found { return fmt.Errorf("bad peer %q", pid) } - i := &fakeInteractor{ - fr: fr, - ctx: ctx, - sendCh: make(chan fakeMessage, 1), - recvCh: make(chan fakeMessage), + r, w := io.Pipe() + defer r.Close() + defer w.Close() + stream := struct { + io.Reader + io.Writer + }{ + Reader: r, + Writer: w, } - i.sendCh <- fakeMessage{data: initialRequest} select { case p.reqCh <- incomingRequest{ - sendCh: i.recvCh, - recvCh: i.sendCh, + initialRequest: initialRequest, + stream: stream, }: case <-ctx.Done(): return ctx.Err() } - _, err := handler(ctx, i) - return err + return callback(ctx, stream) } -func (fr *fakeRequester) InteractiveRequest( +func (fr *fakeRequester) 
StreamRequest( ctx context.Context, pid p2p.Peer, initialRequest []byte, - handler server.InteractiveHandler, - failure func(error), + callback server.StreamRequestCallback, + extraProtocols ...string, ) error { - go func() { - err := fr.request(ctx, pid, initialRequest, handler) - if err != nil { - failure(err) - } - }() - return nil + return fr.request(ctx, pid, initialRequest, callback) } type sliceIterator struct { @@ -235,18 +175,14 @@ type fakeRound struct { } func (r *fakeRound) handleMessages(t *testing.T, c Conduit) error { - // fmt.Fprintf(os.Stderr, "fakeRound %q: handleMessages\n", r.name) var msgs []SyncMessage for { msg, err := c.NextMessage() if err != nil { - // fmt.Fprintf(os.Stderr, "fakeRound %q: error getting message: %v\n", r.name, err) return fmt.Errorf("NextMessage(): %w", err) } else if msg == nil { - // fmt.Fprintf(os.Stderr, "fakeRound %q: consumed all messages\n", r.name) break } - // fmt.Fprintf(os.Stderr, "fakeRound %q: got message %#v\n", r.name, msg) msgs = append(msgs, msg) if msg.Type() == MessageTypeDone || msg.Type() == MessageTypeEndRound { break @@ -268,25 +204,35 @@ func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { return nil } -func makeTestHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, done chan struct{}, rounds []fakeRound) server.InteractiveHandler { - return func(ctx context.Context, i server.Interactor) (time.Duration, error) { - defer func() { - if done != nil { - close(done) - } - }() +func makeTestStreamHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, rounds []fakeRound) server.StreamHandler { + cbk := makeTestRequestCallback(t, c, newValue, rounds) + return func(ctx context.Context, initialRequest []byte, stream io.ReadWriter) error { + t.Logf("init request bytes: %d", len(initialRequest)) + s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(initialRequest), stream), + Writer: 
stream, + } + return cbk(ctx, s) + } +} + +func makeTestRequestCallback(t *testing.T, c *wireConduit, newValue NewValueFunc, rounds []fakeRound) server.StreamRequestCallback { + return func(ctx context.Context, stream io.ReadWriter) error { if c == nil { - c = &wireConduit{i: i, newValue: newValue} + c = &wireConduit{stream: stream, newValue: newValue} } else { - c.i = i + c.stream = stream } for _, round := range rounds { if err := round.handleConversation(t, c); err != nil { - done = nil - return 0, err + return err } } - return 0, nil + return nil } } @@ -296,7 +242,7 @@ func TestWireConduit(t *testing.T) { hs[n] = types.RandomHash() } fp := types.Hash12(hs[2][:12]) - srvHandler := makeTestHandler(t, nil, func() any { return new(fakeValue) }, nil, []fakeRound{ + srvHandler := makeTestStreamHandler(t, nil, func() any { return new(fakeValue) }, []fakeRound{ { name: "server got 1st request", expectMsgs: []SyncMessage{ @@ -369,8 +315,7 @@ func TestWireConduit(t *testing.T) { return c.SendEndRound() }) require.NoError(t, err) - done := make(chan struct{}) - clientHandler := makeTestHandler(t, &c, c.newValue, done, []fakeRound{ + clientCbk := makeTestRequestCallback(t, &c, c.newValue, []fakeRound{ { name: "client got 1st response", expectMsgs: []SyncMessage{ @@ -410,15 +355,11 @@ func TestWireConduit(t *testing.T) { }, }, }) - err = client.InteractiveRequest(context.Background(), "srv", initReq, clientHandler, func(err error) { - t.Errorf("fail handler called: %v", err) - close(done) - }) + err = client.StreamRequest(context.Background(), "srv", initReq, clientCbk) require.NoError(t, err) - <-done } -type getRequesterFunc func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) +type getRequesterFunc func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) func withClientServer( storeA, storeB ItemStore, @@ -443,7 +384,7 @@ func withClientServer( } func fakeRequesterGetter(t *testing.T) 
getRequesterFunc { - return func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { pid := p2p.Peer(name) return newFakeRequester(pid, handler, peers...), pid } @@ -457,7 +398,7 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { server.WithTimeout(10 * time.Second), server.WithLog(logtest.New(t)), } - return func(name string, handler server.InteractiveHandler, peers ...requester) (requester, p2p.Peer) { + return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { if len(peers) == 0 { return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() } diff --git a/hashsync/interface.go b/hashsync/interface.go index 4b4c21d5f3..61de430f26 100644 --- a/hashsync/interface.go +++ b/hashsync/interface.go @@ -9,5 +9,5 @@ import ( type requester interface { Run(context.Context) error - InteractiveRequest(context.Context, p2p.Peer, []byte, server.InteractiveHandler, func(error)) error + StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error } diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index e9e7483ec5..9319904b40 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -80,11 +80,10 @@ type NewValueFunc func() any // Conduit handles receiving and sending peer messages type Conduit interface { - // NextMessage returns the next SyncMessage, or nil if there - // are no more SyncMessages. NextMessage is only called after - // a NextItem call indicates that there are no more items. - // NextMessage will not be called after any of Send...() - // methods is invoked + // NextMessage returns the next SyncMessage, or nil if there are no more + // SyncMessages for this session. NextMessage is only called after a NextItem call + // indicates that there are no more items. 
NextMessage should not be called after + // any of Send...() methods is invoked NextMessage() (SyncMessage, error) // SendFingerprint sends range fingerprint to the peer. // Count must be > 0 From 06f89dce89a9094b225a38c986d3ab05d3445e04 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 28 Mar 2024 20:53:16 +0400 Subject: [PATCH 16/76] hashsync: implement multi-peer split-sync --- hashsync/interface.go | 21 + hashsync/mocks_test.go | 532 ++++++++++++++++++ hashsync/multipeer.go | 331 +++++++++++ hashsync/split_sync.go | 208 +++++++ hashsync/split_sync_test.go | 221 ++++++++ hashsync/sync_queue.go | 105 ++++ hashsync/sync_queue_test.go | 68 +++ hashsync/sync_tree.go | 22 +- ...sync_trees_store.go => sync_tree_store.go} | 0 9 files changed, 1503 insertions(+), 5 deletions(-) create mode 100644 hashsync/mocks_test.go create mode 100644 hashsync/multipeer.go create mode 100644 hashsync/split_sync.go create mode 100644 hashsync/split_sync_test.go create mode 100644 hashsync/sync_queue.go create mode 100644 hashsync/sync_queue_test.go rename hashsync/{sync_trees_store.go => sync_tree_store.go} (100%) diff --git a/hashsync/interface.go b/hashsync/interface.go index 61de430f26..e57e2cfc58 100644 --- a/hashsync/interface.go +++ b/hashsync/interface.go @@ -3,11 +3,32 @@ package hashsync import ( "context" + "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/server" ) +//go:generate mockgen -typed -package=hashsync -destination=./mocks_test.go -source=./interface.go + type requester interface { Run(context.Context) error StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error } + +type peerSet interface { + addPeer(p p2p.Peer) + removePeer(p p2p.Peer) + numPeers() int + listPeers() []p2p.Peer + havePeer(p p2p.Peer) bool +} + +type syncBase interface { + derive(p p2p.Peer) syncer + probe(ctx context.Context, p p2p.Peer) (int, error) +} + +type 
syncer interface { + peer() p2p.Peer + sync(ctx context.Context, x, y *types.Hash32) error +} diff --git a/hashsync/mocks_test.go b/hashsync/mocks_test.go new file mode 100644 index 0000000000..b5495106db --- /dev/null +++ b/hashsync/mocks_test.go @@ -0,0 +1,532 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -typed -package=hashsync -destination=./mocks_test.go -source=./interface.go +// + +// Package hashsync is a generated GoMock package. +package hashsync + +import ( + context "context" + reflect "reflect" + + types "github.com/spacemeshos/go-spacemesh/common/types" + p2p "github.com/spacemeshos/go-spacemesh/p2p" + server "github.com/spacemeshos/go-spacemesh/p2p/server" + gomock "go.uber.org/mock/gomock" +) + +// Mockrequester is a mock of requester interface. +type Mockrequester struct { + ctrl *gomock.Controller + recorder *MockrequesterMockRecorder +} + +// MockrequesterMockRecorder is the mock recorder for Mockrequester. +type MockrequesterMockRecorder struct { + mock *Mockrequester +} + +// NewMockrequester creates a new mock instance. +func NewMockrequester(ctrl *gomock.Controller) *Mockrequester { + mock := &Mockrequester{ctrl: ctrl} + mock.recorder = &MockrequesterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *Mockrequester) EXPECT() *MockrequesterMockRecorder { + return m.recorder +} + +// Run mocks base method. +func (m *Mockrequester) Run(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Run indicates an expected call of Run. 
+func (mr *MockrequesterMockRecorder) Run(arg0 any) *MockrequesterRunCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*Mockrequester)(nil).Run), arg0) + return &MockrequesterRunCall{Call: call} +} + +// MockrequesterRunCall wrap *gomock.Call +type MockrequesterRunCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockrequesterRunCall) Return(arg0 error) *MockrequesterRunCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockrequesterRunCall) Do(f func(context.Context) error) *MockrequesterRunCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockrequesterRunCall) DoAndReturn(f func(context.Context) error) *MockrequesterRunCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// StreamRequest mocks base method. +func (m *Mockrequester) StreamRequest(arg0 context.Context, arg1 p2p.Peer, arg2 []byte, arg3 server.StreamRequestCallback, arg4 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StreamRequest", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StreamRequest indicates an expected call of StreamRequest. +func (mr *MockrequesterMockRecorder) StreamRequest(arg0, arg1, arg2, arg3 any, arg4 ...any) *MockrequesterStreamRequestCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamRequest", reflect.TypeOf((*Mockrequester)(nil).StreamRequest), varargs...) 
+ return &MockrequesterStreamRequestCall{Call: call} +} + +// MockrequesterStreamRequestCall wrap *gomock.Call +type MockrequesterStreamRequestCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockrequesterStreamRequestCall) Return(arg0 error) *MockrequesterStreamRequestCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockrequesterStreamRequestCall) Do(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockrequesterStreamRequestCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockrequesterStreamRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockrequesterStreamRequestCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockpeerSet is a mock of peerSet interface. +type MockpeerSet struct { + ctrl *gomock.Controller + recorder *MockpeerSetMockRecorder +} + +// MockpeerSetMockRecorder is the mock recorder for MockpeerSet. +type MockpeerSetMockRecorder struct { + mock *MockpeerSet +} + +// NewMockpeerSet creates a new mock instance. +func NewMockpeerSet(ctrl *gomock.Controller) *MockpeerSet { + mock := &MockpeerSet{ctrl: ctrl} + mock.recorder = &MockpeerSetMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockpeerSet) EXPECT() *MockpeerSetMockRecorder { + return m.recorder +} + +// addPeer mocks base method. +func (m *MockpeerSet) addPeer(p p2p.Peer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "addPeer", p) +} + +// addPeer indicates an expected call of addPeer. 
+func (mr *MockpeerSetMockRecorder) addPeer(p any) *MockpeerSetaddPeerCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "addPeer", reflect.TypeOf((*MockpeerSet)(nil).addPeer), p) + return &MockpeerSetaddPeerCall{Call: call} +} + +// MockpeerSetaddPeerCall wrap *gomock.Call +type MockpeerSetaddPeerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpeerSetaddPeerCall) Return() *MockpeerSetaddPeerCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpeerSetaddPeerCall) Do(f func(p2p.Peer)) *MockpeerSetaddPeerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpeerSetaddPeerCall) DoAndReturn(f func(p2p.Peer)) *MockpeerSetaddPeerCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// havePeer mocks base method. +func (m *MockpeerSet) havePeer(p p2p.Peer) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "havePeer", p) + ret0, _ := ret[0].(bool) + return ret0 +} + +// havePeer indicates an expected call of havePeer. 
+func (mr *MockpeerSetMockRecorder) havePeer(p any) *MockpeerSethavePeerCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "havePeer", reflect.TypeOf((*MockpeerSet)(nil).havePeer), p) + return &MockpeerSethavePeerCall{Call: call} +} + +// MockpeerSethavePeerCall wrap *gomock.Call +type MockpeerSethavePeerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpeerSethavePeerCall) Return(arg0 bool) *MockpeerSethavePeerCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpeerSethavePeerCall) Do(f func(p2p.Peer) bool) *MockpeerSethavePeerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpeerSethavePeerCall) DoAndReturn(f func(p2p.Peer) bool) *MockpeerSethavePeerCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// listPeers mocks base method. +func (m *MockpeerSet) listPeers() []p2p.Peer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "listPeers") + ret0, _ := ret[0].([]p2p.Peer) + return ret0 +} + +// listPeers indicates an expected call of listPeers. 
+func (mr *MockpeerSetMockRecorder) listPeers() *MockpeerSetlistPeersCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "listPeers", reflect.TypeOf((*MockpeerSet)(nil).listPeers)) + return &MockpeerSetlistPeersCall{Call: call} +} + +// MockpeerSetlistPeersCall wrap *gomock.Call +type MockpeerSetlistPeersCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpeerSetlistPeersCall) Return(arg0 []p2p.Peer) *MockpeerSetlistPeersCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpeerSetlistPeersCall) Do(f func() []p2p.Peer) *MockpeerSetlistPeersCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpeerSetlistPeersCall) DoAndReturn(f func() []p2p.Peer) *MockpeerSetlistPeersCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// numPeers mocks base method. +func (m *MockpeerSet) numPeers() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "numPeers") + ret0, _ := ret[0].(int) + return ret0 +} + +// numPeers indicates an expected call of numPeers. 
+func (mr *MockpeerSetMockRecorder) numPeers() *MockpeerSetnumPeersCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "numPeers", reflect.TypeOf((*MockpeerSet)(nil).numPeers)) + return &MockpeerSetnumPeersCall{Call: call} +} + +// MockpeerSetnumPeersCall wrap *gomock.Call +type MockpeerSetnumPeersCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpeerSetnumPeersCall) Return(arg0 int) *MockpeerSetnumPeersCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpeerSetnumPeersCall) Do(f func() int) *MockpeerSetnumPeersCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpeerSetnumPeersCall) DoAndReturn(f func() int) *MockpeerSetnumPeersCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// removePeer mocks base method. +func (m *MockpeerSet) removePeer(p p2p.Peer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "removePeer", p) +} + +// removePeer indicates an expected call of removePeer. 
+func (mr *MockpeerSetMockRecorder) removePeer(p any) *MockpeerSetremovePeerCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removePeer", reflect.TypeOf((*MockpeerSet)(nil).removePeer), p) + return &MockpeerSetremovePeerCall{Call: call} +} + +// MockpeerSetremovePeerCall wrap *gomock.Call +type MockpeerSetremovePeerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpeerSetremovePeerCall) Return() *MockpeerSetremovePeerCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpeerSetremovePeerCall) Do(f func(p2p.Peer)) *MockpeerSetremovePeerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpeerSetremovePeerCall) DoAndReturn(f func(p2p.Peer)) *MockpeerSetremovePeerCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MocksyncBase is a mock of syncBase interface. +type MocksyncBase struct { + ctrl *gomock.Controller + recorder *MocksyncBaseMockRecorder +} + +// MocksyncBaseMockRecorder is the mock recorder for MocksyncBase. +type MocksyncBaseMockRecorder struct { + mock *MocksyncBase +} + +// NewMocksyncBase creates a new mock instance. +func NewMocksyncBase(ctrl *gomock.Controller) *MocksyncBase { + mock := &MocksyncBase{ctrl: ctrl} + mock.recorder = &MocksyncBaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MocksyncBase) EXPECT() *MocksyncBaseMockRecorder { + return m.recorder +} + +// derive mocks base method. +func (m *MocksyncBase) derive(p p2p.Peer) syncer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "derive", p) + ret0, _ := ret[0].(syncer) + return ret0 +} + +// derive indicates an expected call of derive. 
+func (mr *MocksyncBaseMockRecorder) derive(p any) *MocksyncBasederiveCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "derive", reflect.TypeOf((*MocksyncBase)(nil).derive), p) + return &MocksyncBasederiveCall{Call: call} +} + +// MocksyncBasederiveCall wrap *gomock.Call +type MocksyncBasederiveCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MocksyncBasederiveCall) Return(arg0 syncer) *MocksyncBasederiveCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MocksyncBasederiveCall) Do(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MocksyncBasederiveCall) DoAndReturn(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// probe mocks base method. +func (m *MocksyncBase) probe(ctx context.Context, p p2p.Peer) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "probe", ctx, p) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// probe indicates an expected call of probe. 
+func (mr *MocksyncBaseMockRecorder) probe(ctx, p any) *MocksyncBaseprobeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "probe", reflect.TypeOf((*MocksyncBase)(nil).probe), ctx, p) + return &MocksyncBaseprobeCall{Call: call} +} + +// MocksyncBaseprobeCall wrap *gomock.Call +type MocksyncBaseprobeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MocksyncBaseprobeCall) Return(arg0 int, arg1 error) *MocksyncBaseprobeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MocksyncBaseprobeCall) Do(f func(context.Context, p2p.Peer) (int, error)) *MocksyncBaseprobeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MocksyncBaseprobeCall) DoAndReturn(f func(context.Context, p2p.Peer) (int, error)) *MocksyncBaseprobeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Mocksyncer is a mock of syncer interface. +type Mocksyncer struct { + ctrl *gomock.Controller + recorder *MocksyncerMockRecorder +} + +// MocksyncerMockRecorder is the mock recorder for Mocksyncer. +type MocksyncerMockRecorder struct { + mock *Mocksyncer +} + +// NewMocksyncer creates a new mock instance. +func NewMocksyncer(ctrl *gomock.Controller) *Mocksyncer { + mock := &Mocksyncer{ctrl: ctrl} + mock.recorder = &MocksyncerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *Mocksyncer) EXPECT() *MocksyncerMockRecorder { + return m.recorder +} + +// peer mocks base method. +func (m *Mocksyncer) peer() p2p.Peer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "peer") + ret0, _ := ret[0].(p2p.Peer) + return ret0 +} + +// peer indicates an expected call of peer. 
+func (mr *MocksyncerMockRecorder) peer() *MocksyncerpeerCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "peer", reflect.TypeOf((*Mocksyncer)(nil).peer)) + return &MocksyncerpeerCall{Call: call} +} + +// MocksyncerpeerCall wrap *gomock.Call +type MocksyncerpeerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MocksyncerpeerCall) Return(arg0 p2p.Peer) *MocksyncerpeerCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MocksyncerpeerCall) Do(f func() p2p.Peer) *MocksyncerpeerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MocksyncerpeerCall) DoAndReturn(f func() p2p.Peer) *MocksyncerpeerCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// sync mocks base method. +func (m *Mocksyncer) sync(ctx context.Context, x, y *types.Hash32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "sync", ctx, x, y) + ret0, _ := ret[0].(error) + return ret0 +} + +// sync indicates an expected call of sync. 
+func (mr *MocksyncerMockRecorder) sync(ctx, x, y any) *MocksyncersyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "sync", reflect.TypeOf((*Mocksyncer)(nil).sync), ctx, x, y) + return &MocksyncersyncCall{Call: call} +} + +// MocksyncersyncCall wrap *gomock.Call +type MocksyncersyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MocksyncersyncCall) Return(arg0 error) *MocksyncersyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MocksyncersyncCall) Do(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MocksyncersyncCall) DoAndReturn(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/hashsync/multipeer.go b/hashsync/multipeer.go new file mode 100644 index 0000000000..ebdbd286dd --- /dev/null +++ b/hashsync/multipeer.go @@ -0,0 +1,331 @@ +package hashsync + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/jonboulle/clockwork" + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/p2p" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +// type DataHandler func(context.Context, types.Hash32, p2p.Peer, any) error + +// type dataItem struct { +// key types.Hash32 +// value any +// } + +// type dataItemHandler func(di dataItem) + +// type derivedStore struct { +// ItemStore +// handler dataItemHandler +// // itemCh chan dataItem +// // // TODO: don't embed context in the struct +// // ctx context.Context +// } + +// func (s *derivedStore) Add(k Ordered, v any) { +// s.ItemStore.Add(k, v) +// s.handler(dataItem{key: k.(types.Hash32), value: v}) +// // select { +// // case <-s.ctx.Done(): +// // case s.itemCh <- dataItem{key: k.(types.Hash32), value: v}: +// // } +// } + +type 
probeResult struct { + probed map[p2p.Peer]int + minCount int + maxCount int +} + +// type peerReconciler struct { +// st SyncTree +// } + +type MultiPeerReconcilerOpt func(mpr *MultiPeerReconciler) + +func WithMinFullSyncCount(count int) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.minPartSyncCount = count + } +} + +func WithMinFullFraction(frac float64) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.minFullFraction = frac + } +} + +// func WithMinPartSyncPeers(n int) MultiPeerReconcilerOpt { +// return func(mpr *MultiPeerReconciler) { +// mpr.minPartSyncPeers = n +// } +// } + +// func WithPeerSyncTimeout(t time.Duration) MultiPeerReconcilerOpt { +// return func(mpr *MultiPeerReconciler) { +// mpr.peerSyncTimeout = t +// } +// } + +func WithSplitSyncGracePeriod(t time.Duration) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.splitSyncGracePeriod = t + } +} + +func withClock(clock clockwork.Clock) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.clock = clock + } +} + +type MultiPeerReconciler struct { + logger zap.Logger + // minPartSyncPeers int + minPartSyncCount int + minFullFraction float64 + splitSyncGracePeriod time.Duration + // peerSyncTimeout time.Duration + syncBase syncBase + peerLock sync.Mutex + peers map[p2p.Peer]struct{} + clock clockwork.Clock +} + +func NewMultiPeerReconciler(logger zap.Logger, syncBase syncBase, opts ...MultiPeerReconcilerOpt) *MultiPeerReconciler { + return &MultiPeerReconciler{ + // minPartSyncPeers: 2, + minPartSyncCount: 1000, + minFullFraction: 0.95, + splitSyncGracePeriod: time.Minute, + syncBase: syncBase, + clock: clockwork.NewRealClock(), + } +} + +func (mpr *MultiPeerReconciler) addPeer(p p2p.Peer) { + mpr.peerLock.Lock() + defer mpr.peerLock.Unlock() + mpr.peers[p] = struct{}{} +} + +func (mpr *MultiPeerReconciler) removePeer(p p2p.Peer) { + mpr.peerLock.Lock() + defer mpr.peerLock.Unlock() + delete(mpr.peers, p) +} 
+ +func (mpr *MultiPeerReconciler) numPeers() int { + mpr.peerLock.Lock() + defer mpr.peerLock.Unlock() + return len(mpr.peers) +} + +func (mpr *MultiPeerReconciler) listPeers() []p2p.Peer { + mpr.peerLock.Lock() + defer mpr.peerLock.Unlock() + return maps.Keys(mpr.peers) +} + +func (mpr *MultiPeerReconciler) havePeer(p p2p.Peer) bool { + mpr.peerLock.Lock() + defer mpr.peerLock.Unlock() + _, found := mpr.peers[p] + return found +} + +func (mpr *MultiPeerReconciler) probePeers(ctx context.Context) (*probeResult, error) { + var pr probeResult + for _, p := range mpr.listPeers() { + count, err := mpr.syncBase.probe(ctx, p) + if err != nil { + log.Warning("error probing the peer", zap.Any("peer", p), zap.Error(err)) + if errors.Is(err, context.Canceled) { + return nil, err + } + continue + } + if pr.probed == nil { + pr.probed = map[p2p.Peer]int{ + p: count, + } + pr.minCount = count + pr.maxCount = count + } else { + pr.probed[p] = count + if count < pr.minCount { + pr.minCount = count + } + if count > pr.maxCount { + pr.maxCount = count + } + } + } + return &pr, nil +} + +// func (mpr *MultiPeerReconciler) splitSync(ctx context.Context, peers []p2p.Peer) error { +// // Use priority queue. Higher priority = more time since started syncing +// // Highest priority = not started syncing yet +// // Mark syncRange as synced when it's done, next time it's popped from the queue, +// // it will be dropped +// // When picking up an entry which is already being synced, start with +// // SyncTree of the entry +// // TODO: when all of the ranges are synced at least once, just return. 
+// // The remaining syncs will be canceled +// // TODO: when no available peers remain, return failure +// if len(peers) == 0 { +// panic("BUG: no peers passed to splitSync") +// } +// syncCtx, cancel := context.WithCancel(ctx) +// defer cancel() +// delim := getDelimiters(len(peers)) +// sq := make(syncQueue, len(peers)) +// var y types.Hash32 +// for n := range sq { +// x := y +// if n == len(peers)-1 { +// y = types.Hash32{} +// } else { +// y = delim[n] +// } +// sq[n] = &syncRange{ +// x: x, +// y: y, +// } +// } +// heap.Init(&sq) +// peers = slices.Clone(peers) +// resCh := make(chan syncResult) +// syncMap := make(map[p2p.Peer]*syncRange) +// numRunning := 0 +// numRemaining := len(peers) +// numPeers := len(peers) +// needGracePeriod := true +// for numRemaining > 0 { +// p := peers[0] +// peers = peers[1:] +// var sr *syncRange +// for len(sq) != 0 { +// sr = heap.Pop(&sq).(*syncRange) +// if !sr.done { +// break +// } +// sr = nil +// } +// if sr == nil { +// panic("BUG: bad syncRange accounting in splitSync") +// } +// syncMap[p] = sr +// var s syncer +// if len(sr.syncers) != 0 { +// // derive from an existing syncer to get sync against +// // more up-to-date data +// s = sr.syncers[len(sr.syncers)-1].derive(p) +// } else { +// s = mpr.syncBase.derive(p) +// } +// sr.syncers = append(sr.syncers, s) +// numRunning++ +// // push this syncRange to the back of the queue as a fresh sync +// // is just starting +// sq.update(sr, mpr.clock.Now()) +// go func() { +// err := s.sync(syncCtx, &sr.x, &sr.y) +// select { +// case <-syncCtx.Done(): +// case resCh <- syncResult{s: s, err: err}: +// } +// }() + +// peers := slices.DeleteFunc(peers, func(p p2p.Peer) bool { +// return !mpr.havePeer(p) +// }) + +// // Grace period: after at least one syncer finishes, wait a bit +// // before assigning it another range to avoid unneeded traffic. 
+// // The grace period ends if any of the syncers fail +// var gpTimer <-chan time.Time +// if needGracePeriod { +// gpTimer = mpr.clock.After(mpr.splitSyncGracePeriod) +// } +// for needGracePeriod && len(peers) == 0 { +// if numRunning == 0 { +// return errors.New("all peers dropped before full sync has completed") +// } + +// var r syncResult +// select { +// case <-syncCtx.Done(): +// return syncCtx.Err() +// case r = <-resCh: +// case <-gpTimer: +// needGracePeriod = false +// } + +// sr, found := syncMap[s.peer()] +// if !found { +// panic("BUG: error in split sync syncMap handling") +// } +// numRunning-- +// delete(syncMap, s.peer()) +// n := slices.Index(sr.syncers, s) +// if n < 0 { +// panic("BUG: bad syncers in syncRange") +// } +// sr.syncers = slices.Delete(sr.syncers, n, n+1) +// if r.err != nil { +// numPeers-- +// mpr.RemovePeer(s.peer()) +// if numPeers == 0 && numRemaining != 0 { +// return errors.New("all peers dropped before full sync has completed") +// } +// if len(sr.syncers) == 0 { +// // prioritize the syncRange for resync after failed +// // sync with no active syncs remaining +// sq.update(sr, time.Time{}) +// } +// needGracePeriod = false +// } else { +// sr.done = true +// peers = append(peers, s.peer()) +// numRemaining-- +// } +// } +// } + +// return nil +// } + +func (mpr *MultiPeerReconciler) run(ctx context.Context) error { + // States: + // A. No peers -> do nothing. + // Got any peers => B + // B. Low on peers. Wait for more to appear + // Lost all peers => A + // Got enough peers => C + // Timeout => C + // C. Probe the peers. Use successfully probed ones in states D/E + // All probes failed => A + // All are low on count (minPartSyncCount) => E + // Some have substantially higher count (minFullFraction) => D + // Otherwise => E + // D. Bounded sync. Subdivide the range by peers and start syncs. 
+ // Use peers with > minPartSyncCount + // Wait for all the syncs to complete/fail + // All syncs succeeded => A + // Any syncs failed => A + // E. Full sync. Run full syncs against each peer + // All syncs completed (success / fail) => F + // F. Wait. Pause for sync interval + // Timeout => A + panic("TBD") +} diff --git a/hashsync/split_sync.go b/hashsync/split_sync.go new file mode 100644 index 0000000000..9ce428cf1b --- /dev/null +++ b/hashsync/split_sync.go @@ -0,0 +1,208 @@ +package hashsync + +import ( + "context" + "encoding/binary" + "errors" + "slices" + "time" + + "github.com/jonboulle/clockwork" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/p2p" +) + +type syncResult struct { + s syncer + err error +} + +type splitSync struct { + logger *zap.Logger + syncBase syncBase + peerSet peerSet + peers []p2p.Peer + gracePeriod time.Duration + clock clockwork.Clock + sq syncQueue + resCh chan syncResult + slowRangeCh chan *syncRange + syncMap map[p2p.Peer]*syncRange + numRunning int + numRemaining int + numPeers int + syncers []syncer + eg *errgroup.Group +} + +func newSplitSync( + logger *zap.Logger, + syncBase syncBase, + peerSet peerSet, + peers []p2p.Peer, + gracePeriod time.Duration, + clock clockwork.Clock, +) *splitSync { + if len(peers) == 0 { + panic("BUG: no peers passed to splitSync") + } + return &splitSync{ + logger: logger, + syncBase: syncBase, + peerSet: peerSet, + peers: peers, + gracePeriod: gracePeriod, + clock: clock, + sq: newSyncQueue(len(peers)), + resCh: make(chan syncResult), + syncMap: make(map[p2p.Peer]*syncRange), + numRemaining: len(peers), + numPeers: len(peers), + } +} + +func (s *splitSync) nextPeer() p2p.Peer { + if len(s.peers) == 0 { + panic("BUG: no peers") + } + p := s.peers[0] + s.peers = s.peers[1:] + return p +} + +func (s *splitSync) startPeerSync(ctx context.Context, p p2p.Peer, sr *syncRange) { + syncer := 
s.syncBase.derive(p) + sr.numSyncers++ + s.numRunning++ + doneCh := make(chan struct{}) + s.eg.Go(func() error { + defer close(doneCh) + err := syncer.sync(ctx, &sr.x, &sr.y) + select { + case <-ctx.Done(): + return ctx.Err() + case s.resCh <- syncResult{s: syncer, err: err}: + return nil + } + }) + gpTimer := s.clock.After(s.gracePeriod) + s.eg.Go(func() error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-doneCh: + case <-gpTimer: + // if another peer finishes it part early, let + // it pick up this range + s.slowRangeCh <- sr + } + return nil + }) +} + +func (s *splitSync) handleSyncResult(r syncResult) error { + sr, found := s.syncMap[r.s.peer()] + if !found { + panic("BUG: error in split sync syncMap handling") + } + s.numRunning-- + delete(s.syncMap, r.s.peer()) + sr.numSyncers-- + if r.err != nil { + s.numPeers-- + s.peerSet.removePeer(r.s.peer()) + s.logger.Debug("remove failed peer", + zap.Stringer("peer", r.s.peer()), + zap.Int("numPeers", s.numPeers), + zap.Int("numRemaining", s.numRemaining), + zap.Int("numRunning", s.numRunning), + zap.Int("availPeers", len(s.peers))) + if s.numPeers == 0 && s.numRemaining != 0 { + return errors.New("all peers dropped before full sync has completed") + } + if sr.numSyncers == 0 { + // QQQQQ: it has been popped!!!! 
+ // prioritize the syncRange for resync after failed + // sync with no active syncs remaining + s.sq.update(sr, time.Time{}) + } + } else { + sr.done = true + s.peers = append(s.peers, r.s.peer()) + s.numRemaining-- + s.logger.Debug("peer synced successfully", + zap.Stringer("peer", r.s.peer()), + zap.Int("numPeers", s.numPeers), + zap.Int("numRemaining", s.numRemaining), + zap.Int("numRunning", s.numRunning), + zap.Int("availPeers", len(s.peers))) + } + + return nil +} + +func (s *splitSync) clearDeadPeers() { + s.peers = slices.DeleteFunc(s.peers, func(p p2p.Peer) bool { + return !s.peerSet.havePeer(p) + }) +} + +func (s *splitSync) sync(ctx context.Context) error { + sctx, cancel := context.WithCancel(ctx) + defer cancel() + var syncCtx context.Context + s.eg, syncCtx = errgroup.WithContext(sctx) + for s.numRemaining > 0 { + var sr *syncRange + for { + s.logger.Debug("QQQQQ: wait sr") + sr := s.sq.popRange() + if sr != nil { + if sr.done { + continue + } + p := s.nextPeer() + s.syncMap[p] = sr + s.startPeerSync(syncCtx, p, sr) + } + break + } + s.clearDeadPeers() + for s.numRemaining > 0 && (s.sq.empty() || len(s.peers) == 0) { + s.logger.Debug("QQQQQ: loop") + if s.numRunning == 0 && len(s.peers) == 0 { + return errors.New("all peers dropped before full sync has completed") + } + select { + case sr = <-s.slowRangeCh: + // push this syncRange to the back of the queue + s.sq.update(sr, s.clock.Now()) + case <-syncCtx.Done(): + return syncCtx.Err() + case r := <-s.resCh: + if err := s.handleSyncResult(r); err != nil { + return err + } + } + } + s.logger.Debug("QQQQQ: after loop") + } + s.logger.Debug("QQQQQ: wg wait") + return s.eg.Wait() +} + +func getDelimiters(numPeers int) (h []types.Hash32) { + if numPeers < 2 { + return nil + } + inc := (uint64(0x80) << 56) / uint64(numPeers) + h = make([]types.Hash32, numPeers-1) + for i, v := 0, uint64(0); i < numPeers-1; i++ { + v += inc + binary.BigEndian.PutUint64(h[i][:], v<<1) + } + return h +} diff --git 
a/hashsync/split_sync_test.go b/hashsync/split_sync_test.go new file mode 100644 index 0000000000..4e7cd9d181 --- /dev/null +++ b/hashsync/split_sync_test.go @@ -0,0 +1,221 @@ +package hashsync + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" +) + +func hexDelimiters(n int) (r []string) { + for _, h := range getDelimiters(n) { + r = append(r, h.Hex()) + } + return r +} + +func TestGetDelimiters(t *testing.T) { + for _, tc := range []struct { + numPeers int + values []string + }{ + { + numPeers: 0, + values: nil, + }, + { + numPeers: 1, + values: nil, + }, + { + numPeers: 2, + values: []string{ + "0x8000000000000000000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 3, + values: []string{ + "0x5555555555555554000000000000000000000000000000000000000000000000", + "0xaaaaaaaaaaaaaaa8000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 4, + values: []string{ + "0x4000000000000000000000000000000000000000000000000000000000000000", + "0x8000000000000000000000000000000000000000000000000000000000000000", + "0xc000000000000000000000000000000000000000000000000000000000000000", + }, + }, + } { + r := hexDelimiters(tc.numPeers) + if len(tc.values) == 0 { + require.Empty(t, r, "%d delimiters", tc.numPeers) + } else { + require.Equal(t, tc.values, r, "%d delimiters", tc.numPeers) + } + } +} + +type splitSyncTester struct { + testing.TB + + peers []p2p.Peer + clock clockwork.Clock + mtx sync.Mutex + fail map[hexRange]bool + expPeerRanges map[hexRange]int + peerRanges map[hexRange][]p2p.Peer + syncBase *MocksyncBase + peerSet *MockpeerSet + splitSync *splitSync +} + +var tstRanges = []hexRange{ + { + 
"0x0000000000000000000000000000000000000000000000000000000000000000", + "0x4000000000000000000000000000000000000000000000000000000000000000", + }, + { + "0x4000000000000000000000000000000000000000000000000000000000000000", + "0x8000000000000000000000000000000000000000000000000000000000000000", + }, + { + "0x8000000000000000000000000000000000000000000000000000000000000000", + "0xc000000000000000000000000000000000000000000000000000000000000000", + }, + { + "0xc000000000000000000000000000000000000000000000000000000000000000", + "0x0000000000000000000000000000000000000000000000000000000000000000", + }, +} + +func newTestSplitSync(t testing.TB) *splitSyncTester { + ctrl := gomock.NewController(t) + tst := &splitSyncTester{ + peers: make([]p2p.Peer, 4), + clock: clockwork.NewFakeClock(), + fail: make(map[hexRange]bool), + expPeerRanges: map[hexRange]int{ + tstRanges[0]: 0, + tstRanges[1]: 0, + tstRanges[2]: 0, + tstRanges[3]: 0, + }, + peerRanges: make(map[hexRange][]p2p.Peer), + syncBase: NewMocksyncBase(ctrl), + peerSet: NewMockpeerSet(ctrl), + } + for n := range tst.peers { + tst.peers[n] = p2p.Peer(types.RandomBytes(20)) + } + for index, p := range tst.peers { + index := index + p := p + tst.syncBase.EXPECT(). + derive(p). + DoAndReturn(func(peer p2p.Peer) syncer { + s := NewMocksyncer(ctrl) + s.EXPECT().peer().Return(p).AnyTimes() + s.EXPECT(). + sync(gomock.Any(), gomock.Any(), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, x, y *types.Hash32) error { + tst.mtx.Lock() + defer tst.mtx.Unlock() + require.NotNil(t, ctx) + require.NotNil(t, x) + require.NotNil(t, y) + k := hexRange{x.Hex(), y.Hex()} + tst.peerRanges[k] = append(tst.peerRanges[k], peer) + count, found := tst.expPeerRanges[k] + require.True(t, found, "peer range not found: x %s y %s", x, y) + if tst.fail[k] { + t.Logf("ERR: peer %d x %s y %s", index, x.String(), y.String()) + tst.fail[k] = false + return errors.New("injected fault") + } else { + t.Logf("OK: peer %d x %s y %s", index, x.String(), y.String()) + tst.expPeerRanges[k] = count + 1 + } + return nil + }) + return s + }). + AnyTimes() + } + tst.peerSet.EXPECT(). + havePeer(gomock.Any()). + DoAndReturn(func(p p2p.Peer) bool { + require.Contains(t, tst.peers, p) + return true + }). + AnyTimes() + tst.splitSync = newSplitSync( + zaptest.NewLogger(t), + tst.syncBase, + tst.peerSet, + tst.peers, + time.Minute, + tst.clock, + ) + return tst +} + +func TestSplitSync(t *testing.T) { + tst := newTestSplitSync(t) + var eg errgroup.Group + eg.Go(func() error { + return tst.splitSync.sync(context.Background()) + }) + require.NoError(t, eg.Wait()) + for pr, count := range tst.expPeerRanges { + require.Equal(t, 1, count, "bad sync count: x %s y %s", pr[0], pr[1]) + } +} + +func TestSplitSyncRetry(t *testing.T) { + tst := newTestSplitSync(t) + tst.fail[tstRanges[1]] = true + tst.fail[tstRanges[2]] = true + removedPeers := make(map[p2p.Peer]bool) + tst.peerSet.EXPECT().removePeer(gomock.Any()).DoAndReturn(func(peer p2p.Peer) { + require.NotContains(t, removedPeers, peer) + removedPeers[peer] = true + }).Times(2) + var eg errgroup.Group + eg.Go(func() error { + return tst.splitSync.sync(context.Background()) + }) + require.NoError(t, eg.Wait()) + for pr, count := range tst.expPeerRanges { + require.False(t, tst.fail[pr], "fail cleared for x %s y %s", pr[0], pr[1]) + require.Equal(t, 1, count, "peer range not synced: x %s y %s", pr[0], pr[1]) + 
} + for _, r := range []hexRange{tstRanges[1], tstRanges[2]} { + haveFailedPeers := false + for _, peer := range tst.peerRanges[r] { + if removedPeers[peer] { + haveFailedPeers = true + } + } + require.True(t, haveFailedPeers) + } +} + +// TODO: test cancel +// TODO: test sync failure +// TODO: test out of peers due to failure +// TODO: test dropping failed peers (there should be a hook so that the peer connection is terminated) +// TODO: log peer sync failures +// TODO: log sync starts +// TODO: log overlapping syncs diff --git a/hashsync/sync_queue.go b/hashsync/sync_queue.go new file mode 100644 index 0000000000..d4805d7b29 --- /dev/null +++ b/hashsync/sync_queue.go @@ -0,0 +1,105 @@ +package hashsync + +import ( + "container/heap" + "time" + + "github.com/spacemeshos/go-spacemesh/common/types" +) + +type syncRange struct { + x, y types.Hash32 + lastSyncStarted time.Time + done bool + numSyncers int + index int +} + +type syncQueue []*syncRange + +// Len implements heap.Interface. +func (sq syncQueue) Len() int { return len(sq) } + +// Less implements heap.Interface. +func (sq syncQueue) Less(i, j int) bool { + // We want Pop to give us syncRange for which which sync has started the + // earliest. Items which are not being synced are considered "most earliest" + return sq[i].lastSyncStarted.Before(sq[j].lastSyncStarted) +} + +// Swap implements heap.Interface. +func (sq syncQueue) Swap(i, j int) { + sq[i], sq[j] = sq[j], sq[i] + sq[i].index = i + sq[j].index = j +} + +// Push implements heap.Interface. +func (sq *syncQueue) Push(i any) { + n := len(*sq) + sr := i.(*syncRange) + sr.index = n + *sq = append(*sq, sr) +} + +// Pop implements heap.Interface. 
+func (sq *syncQueue) Pop() any { + old := *sq + n := len(old) + sr := old[n-1] + old[n-1] = nil // avoid memory leak + sr.index = -1 // not in the queue anymore + *sq = old[0 : n-1] + return sr +} + +func newSyncQueue(numPeers int) syncQueue { + delim := getDelimiters(numPeers) + var y types.Hash32 + sq := make(syncQueue, numPeers) + for n := range sq { + x := y + if n == numPeers-1 { + y = types.Hash32{} + } else { + y = delim[n] + } + sq[n] = &syncRange{ + x: x, + y: y, + } + } + heap.Init(&sq) + return sq +} + +func (sq *syncQueue) empty() bool { + return len(*sq) == 0 +} + +func (sq *syncQueue) popRange() *syncRange { + if sq.empty() { + return nil + } + sr := heap.Pop(sq).(*syncRange) + sr.index = -1 + return sr +} + +func (sq *syncQueue) pushRange(sr *syncRange) { + if sr.done { + panic("BUG: pushing a finished syncRange into the queue") + } + if sr.index == -1 { + heap.Push(sq, sr) + } +} + +func (sq *syncQueue) update(sr *syncRange, lastSyncStarted time.Time) { + sr.lastSyncStarted = lastSyncStarted + if sr.index == -1 { + sq.pushRange(sr) + } else { + heap.Fix(sq, sr.index) + } +} diff --git a/hashsync/sync_queue_test.go b/hashsync/sync_queue_test.go new file mode 100644 index 0000000000..8b0ca984ca --- /dev/null +++ b/hashsync/sync_queue_test.go @@ -0,0 +1,68 @@ +package hashsync + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type hexRange [2]string + +func TestSyncQueue(t *testing.T) { + expPeerRanges := map[hexRange]bool{ + { + "0x0000000000000000000000000000000000000000000000000000000000000000", + "0x4000000000000000000000000000000000000000000000000000000000000000", + }: false, + { + "0x4000000000000000000000000000000000000000000000000000000000000000", + "0x8000000000000000000000000000000000000000000000000000000000000000", + }: false, + { + "0x8000000000000000000000000000000000000000000000000000000000000000", + "0xc000000000000000000000000000000000000000000000000000000000000000", + }: false, + { + 
"0xc000000000000000000000000000000000000000000000000000000000000000", + "0x0000000000000000000000000000000000000000000000000000000000000000", + }: false, + } + sq := newSyncQueue(4) + startTime := time.Now() + pushed := make([]hexRange, 4) + for i := 0; i < 4; i++ { + sr := sq.popRange() + require.NotNil(t, sr) + require.True(t, sr.lastSyncStarted.IsZero()) + require.False(t, sr.done) + require.Zero(t, sr.numSyncers) + k := hexRange{sr.x.Hex(), sr.y.Hex()} + processed, found := expPeerRanges[k] + require.True(t, found) + require.False(t, processed) + expPeerRanges[k] = true + t.Logf("push range %v at %v", k, sr.lastSyncStarted) + if i != 1 { + sr.lastSyncStarted = startTime + sq.pushRange(sr) // pushed to the end + } else { + // use update for one of the items + // instead of pushing with proper time + sq.update(sr, startTime) + } + if i == 0 { + sq.pushRange(sr) // should do nothing + } + startTime = startTime.Add(10 * time.Second) + pushed[i] = k + } + require.Len(t, sq, 4) + for i := 0; i < 4; i++ { + sr := sq.popRange() + k := hexRange{sr.x.Hex(), sr.y.Hex()} + t.Logf("pop range %v at %v", k, sr.lastSyncStarted) + require.Equal(t, pushed[i], k) + } + require.Empty(t, sq) +} diff --git a/hashsync/sync_tree.go b/hashsync/sync_tree.go index 181fa68ba8..bd2990ad00 100644 --- a/hashsync/sync_tree.go +++ b/hashsync/sync_tree.go @@ -7,6 +7,7 @@ import ( "reflect" "slices" "strings" + "sync" ) type Ordered interface { @@ -32,6 +33,14 @@ func (fpred FingerprintPredicate) Match(y any) bool { } type SyncTree interface { + // Make a copy of the tree. The copy shares the structure with + // this tree but all its nodes are copy-on-write, so any + // changes in the copied tree do not affect this one and are + // safe to perform in another goroutine. 
The copy operation is
+	// O(n) where n is the number of nodes added to this tree
+	// since its creation via either NewSyncTree function or this
+	// Copy method, or the last call of this Copy method for this
+	// tree, whichever occurs last. The call to Copy is thread-safe.
 	Copy() SyncTree
 	Fingerprint() any
 	Add(k Ordered)
@@ -302,6 +311,7 @@ func (sn *syncTreeNode) cleanCloned() {
 }
 
 type syncTree struct {
+	rootMtx      sync.Mutex
 	m            Monoid
 	root         *syncTreeNode
 	cachedMinPtr *syncTreePointer
@@ -313,9 +323,11 @@ func NewSyncTree(m Monoid) SyncTree {
 }
 
 func (st *syncTree) Copy() SyncTree {
-	// Clean flagCloned from any nodes created specifically
-	// for this subtree. This will mean they will have to be
-	// re-cloned if they need to be changed again.
+	st.rootMtx.Lock()
+	defer st.rootMtx.Unlock()
+	// Clean flagCloned from any nodes created specifically for
+	// this tree. This will mean they will have to be re-cloned if
+	// they need to be changed again.
 	st.root.cleanCloned()
 	// Don't reuse cachedMinPtr / cachedMaxPtr for the cloned
 	// tree to be on the safe side
@@ -501,6 +513,8 @@ func (st *syncTree) Set(k Ordered, v any) {
 }
 
 func (st *syncTree) add(k Ordered, v any, set bool) {
+	st.rootMtx.Lock()
+	defer st.rootMtx.Unlock()
 	st.root = st.insert(st.root, k, v, true, set)
 	if st.root.flags&flagBlack == 0 {
 		st.root = st.ensureCloned(st.root)
@@ -890,8 +904,6 @@ func (st *syncTree) Dump() string {
 	return sb.String()
 }
 
-// TBD: !!! values and Lookup (via findGTENode) !!!
-// TODO: rename SyncTreeNode to just Node, SyncTree to SyncTree // TODO: use sync.Pool for node alloc // see also: // https://www.akshaydeo.com/blog/2017/12/23/How-did-I-improve-latency-by-700-percent-using-syncPool/ diff --git a/hashsync/sync_trees_store.go b/hashsync/sync_tree_store.go similarity index 100% rename from hashsync/sync_trees_store.go rename to hashsync/sync_tree_store.go From 34a1943354865de9a184cde10631d6ed53cd68af Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 3 May 2024 22:37:09 +0400 Subject: [PATCH 17/76] hashsync: add minhash probing --- hashsync/handler.go | 103 ++++++++++++++++----- hashsync/handler_test.go | 45 +++++----- hashsync/rangesync.go | 167 +++++++++++++++++++++++++++++------ hashsync/rangesync_test.go | 159 ++++++++++++++++++++++----------- hashsync/sync_tree_store.go | 17 +++- hashsync/wire_types.go | 100 ++++++++++++++++++--- hashsync/wire_types_scale.go | 115 +++++++++++++++++++++++- 7 files changed, 574 insertions(+), 132 deletions(-) diff --git a/hashsync/handler.go b/hashsync/handler.go index 5e51632bf4..055896c056 100644 --- a/hashsync/handler.go +++ b/hashsync/handler.go @@ -134,8 +134,14 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { return nil, err } return &m, nil - case MessageTypeQuery: - var m QueryMessage + case MessageTypeProbe: + var m ProbeMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeProbeResponse: + var m ProbeResponseMessage if _, err := codec.DecodeFrom(c.stream, &m); err != nil { return nil, err } @@ -146,6 +152,7 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { } func (c *wireConduit) send(m sendable) error { + // fmt.Fprintf(os.Stderr, "QQQQQ: send: %s: %#v\n", m.Type(), m) var stream io.Writer if c.initReqBuf != nil { stream = c.initReqBuf @@ -219,15 +226,46 @@ func (c *wireConduit) SendDone() error { return c.send(&DoneMessage{}) } -func (c *wireConduit) SendQuery(x, y Ordered) error { 
+func (c *wireConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) error {
+	m := &ProbeMessage{
+		RangeFingerprint: fingerprint.(types.Hash12),
+		SampleSize:       uint32(sampleSize),
+	}
 	if x == nil && y == nil {
-		return c.send(&QueryMessage{})
+		return c.send(m)
 	} else if x == nil || y == nil {
-		panic("BUG: SendQuery: bad range: just one of the bounds is nil")
+		panic("BUG: SendProbe: bad range: just one of the bounds is nil")
 	}
 	xh := x.(types.Hash32)
 	yh := y.(types.Hash32)
-	return c.send(&QueryMessage{RangeX: &xh, RangeY: &yh})
+	m.RangeX = &xh
+	m.RangeY = &yh
+	return c.send(m)
+}
+
+func (c *wireConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error {
+	m := &ProbeResponseMessage{
+		RangeFingerprint: fingerprint.(types.Hash12),
+		NumItems:         uint32(count),
+		Sample:           make([]MinhashSampleItem, sampleSize),
+	}
+	// fmt.Fprintf(os.Stderr, "QQQQQ: begin sending items\n")
+	for n := 0; n < sampleSize; n++ {
+		m.Sample[n] = MinhashSampleItemFromHash32(it.Key().(types.Hash32))
+		// fmt.Fprintf(os.Stderr, "QQQQQ: m.Sample[%d] = %s\n", n, m.Sample[n])
+		it.Next()
+	}
+	// fmt.Fprintf(os.Stderr, "QQQQQ: end sending items\n")
+	if x == nil && y == nil {
+		return c.send(m)
+	} else if x == nil || y == nil {
+		panic("BUG: SendProbeResponse: bad range: just one of the bounds is nil")
+	}
+	xh := x.(types.Hash32)
+	yh := y.(types.Hash32)
+	m.RangeX = &xh
+	m.RangeY = &yh
+	return c.send(m)
 }
 
 func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, error) {
@@ -252,6 +290,11 @@ func (c *wireConduit) handleStream(stream io.ReadWriter, rsr *RangeSetReconciler
 	}
 }
 
+// ShortenKey implements Conduit.
+func (c *wireConduit) ShortenKey(k Ordered) Ordered { + return MinhashSampleItemFromHash32(k.(types.Hash32)) +} + func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { return func(ctx context.Context, req []byte, stream io.ReadWriter) error { c := wireConduit{newValue: is.New} @@ -299,43 +342,63 @@ func syncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, }) } -func Probe(ctx context.Context, r requester, peer p2p.Peer, opts ...Option) (fp any, count int, err error) { - return boundedProbe(ctx, r, peer, nil, nil, opts) +func Probe(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, opts ...Option) (ProbeResult, error) { + return boundedProbe(ctx, r, peer, is, nil, nil, opts) } -func BoundedProbe(ctx context.Context, r requester, peer p2p.Peer, x, y types.Hash32, opts ...Option) (fp any, count int, err error) { - return boundedProbe(ctx, r, peer, &x, &y, opts) +func BoundedProbe( + ctx context.Context, + r requester, + peer p2p.Peer, + is ItemStore, + x, y types.Hash32, + opts ...Option, +) (ProbeResult, error) { + return boundedProbe(ctx, r, peer, is, &x, &y, opts) } -func boundedProbe(ctx context.Context, r requester, peer p2p.Peer, x, y *types.Hash32, opts []Option) (fp any, count int, err error) { +func boundedProbe( + ctx context.Context, + r requester, + peer p2p.Peer, + is ItemStore, + x, y *types.Hash32, + opts []Option, +) (ProbeResult, error) { + var ( + err error + initReq []byte + info RangeInfo + pr ProbeResult + ) c := wireConduit{ newValue: func() any { return nil }, // not used } - rsr := NewRangeSetReconciler(nil, opts...) - // c.rmmePrint = true - var initReq []byte + rsr := NewRangeSetReconciler(is, opts...) 
if x == nil { initReq, err = c.withInitialRequest(func(c Conduit) error { - return rsr.InitiateProbe(c) + info, err = rsr.InitiateProbe(c) + return err }) } else { initReq, err = c.withInitialRequest(func(c Conduit) error { - return rsr.InitiateBoundedProbe(c, *x, *y) + info, err = rsr.InitiateBoundedProbe(c, *x, *y) + return err }) } if err != nil { - return nil, 0, err + return ProbeResult{}, err } err = r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { c.stream = stream var err error - fp, count, err = rsr.HandleProbeResponse(&c) + pr, err = rsr.HandleProbeResponse(&c, info) return err }) if err != nil { - return nil, 0, err + return ProbeResult{}, err } - return fp, count, nil + return pr, nil } // TODO: request duration diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index 2d98903425..25557c92ee 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -362,12 +362,12 @@ func TestWireConduit(t *testing.T) { type getRequesterFunc func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) func withClientServer( - storeA, storeB ItemStore, + store ItemStore, getRequester getRequesterFunc, opts []Option, toCall func(ctx context.Context, client requester, srvPeerID p2p.Peer), ) { - srvHandler := MakeServerHandler(storeA, opts...) + srvHandler := MakeServerHandler(store, opts...) 
srv, srvPeerID := getRequester("srv", srvHandler) var eg errgroup.Group ctx, cancel := context.WithCancel(context.Background()) @@ -383,7 +383,7 @@ func withClientServer( toCall(ctx, client, srvPeerID) } -func fakeRequesterGetter(t *testing.T) getRequesterFunc { +func fakeRequesterGetter() getRequesterFunc { return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { pid := p2p.Peer(name) return newFakeRequester(pid, handler, peers...), pid @@ -428,7 +428,7 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { var client requester verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { withClientServer( - storeA, storeB, getRequester, opts, + storeA, getRequester, opts, func(ctx context.Context, client requester, srvPeerID p2p.Peer) { err := SyncStore(ctx, client, srvPeerID, storeB, opts...) require.NoError(t, err) @@ -445,7 +445,7 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { func TestWireSync(t *testing.T) { t.Run("fake requester", func(t *testing.T) { - testWireSync(t, fakeRequesterGetter(t)) + testWireSync(t, fakeRequesterGetter()) }) t.Run("p2p", func(t *testing.T) { testWireSync(t, p2pRequesterGetter(t)) @@ -455,33 +455,36 @@ func TestWireSync(t *testing.T) { func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { cfg := xorSyncTestConfig{ maxSendRange: 1, - numTestHashes: 32, - minNumSpecificA: 4, - maxNumSpecificA: 4, - minNumSpecificB: 4, - maxNumSpecificB: 4, + numTestHashes: 10000, + minNumSpecificA: 130, + maxNumSpecificA: 130, + minNumSpecificB: 130, + maxNumSpecificB: 130, } var client requester verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { withClientServer( - storeA, storeB, getRequester, opts, + storeA, getRequester, opts, func(ctx context.Context, client requester, srvPeerID p2p.Peer) { minA := storeA.Min().Key() infoA := storeA.GetRangeInfo(nil, minA, 
minA, -1) - fpA, countA, err := Probe(ctx, client, srvPeerID, opts...) + prA, err := Probe(ctx, client, srvPeerID, storeB, opts...) require.NoError(t, err) - require.Equal(t, infoA.Fingerprint, fpA) - require.Equal(t, infoA.Count, countA) + require.Equal(t, infoA.Fingerprint, prA.FP) + require.Equal(t, infoA.Count, prA.Count) + require.InDelta(t, 0.98, prA.Sim, 0.05, "sim") minA = storeA.Min().Key() partInfoA := storeA.GetRangeInfo(nil, minA, minA, infoA.Count/2) x := partInfoA.Start.Key().(types.Hash32) y := partInfoA.End.Key().(types.Hash32) // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) - fpA, countA, err = BoundedProbe(ctx, client, srvPeerID, x, y, opts...) + prA, err = BoundedProbe(ctx, client, srvPeerID, storeB, x, y, opts...) require.NoError(t, err) - require.Equal(t, partInfoA.Fingerprint, fpA) - require.Equal(t, partInfoA.Count, countA) + require.Equal(t, partInfoA.Fingerprint, prA.FP) + require.Equal(t, partInfoA.Count, prA.Count) + require.InDelta(t, 0.98, prA.Sim, 0.1, "sim") + // QQQQQ: TBD: check prA.Sim and prB.Sim values }) return false }) @@ -490,11 +493,11 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { func TestWireProbe(t *testing.T) { t.Run("fake requester", func(t *testing.T) { - testWireProbe(t, fakeRequesterGetter(t)) - }) - t.Run("p2p", func(t *testing.T) { - testWireProbe(t, p2pRequesterGetter(t)) + testWireProbe(t, fakeRequesterGetter()) }) + // t.Run("p2p", func(t *testing.T) { + // testWireProbe(t, p2pRequesterGetter(t)) + // }) } // TODO: test bounded sync diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 9319904b40..918c6fffb2 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -4,12 +4,15 @@ import ( "errors" "fmt" "reflect" + "slices" "strings" ) const ( defaultMaxSendRange = 16 defaultItemChunkSize = 16 + defaultSampleSize = 200 + maxSampleSize = 1000 ) type MessageType byte @@ -22,7 +25,8 @@ const ( MessageTypeFingerprint MessageTypeRangeContents MessageTypeItemBatch - 
MessageTypeQuery + MessageTypeProbe + MessageTypeProbeResponse ) var messageTypes = []string{ @@ -33,6 +37,8 @@ var messageTypes = []string{ "fingerprint", "rangeContents", "itemBatch", + "probe", + "probeResponse", } func (mtype MessageType) String() string { @@ -79,6 +85,8 @@ func SyncMessageToString(m SyncMessage) string { type NewValueFunc func() any // Conduit handles receiving and sending peer messages +// TODO: replace multiple Send* methods with a single one +// (after de-generalizing messages) type Conduit interface { // NextMessage returns the next SyncMessage, or nil if there are no more // SyncMessages for this session. NextMessage is only called after a NextItem call @@ -104,10 +112,17 @@ type Conduit interface { SendEndRound() error // SendDone sends a message that notifies the peer that sync is finished SendDone() error - // SendQuery sends a message requesting fingerprint and count of the - // whole range or part of the range. The response will never contain any - // actual data items - SendQuery(x, y Ordered) error + // SendProbe sends a message requesting fingerprint and count of the + // whole range or part of the range. If fingerprint is provided and + // it doesn't match the fingerprint on the probe handler side, + // the handler must send a sample subset of its items for MinHash + // calculation. + SendProbe(x, y Ordered, fingerprint any, sampleSize int) error + // SendProbeResponse sends probe response. 
If 'it' is not nil, + // the corresponding items are included in the sample + SendProbeResponse(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error + // ShortenKey shortens the key for minhash calculation + ShortenKey(k Ordered) Ordered } type Option func(r *RangeSetReconciler) @@ -124,6 +139,12 @@ func WithItemChunkSize(n int) Option { } } +func WithSampleSize(s int) Option { + return func(r *RangeSetReconciler) { + r.sampleSize = s + } +} + // Iterator points to in item in ItemStore type Iterator interface { // Equal returns true if this iterator is equal to another Iterator @@ -149,7 +170,9 @@ type ItemStore interface { Add(k Ordered, v any) // GetRangeInfo returns RangeInfo for the item range in the tree. // If count >= 0, at most count items are returned, and RangeInfo - // is returned for the corresponding subrange of the requested range + // is returned for the corresponding subrange of the requested range. + // If both x and y is nil, the whole set of items is used. + // If only x or only y is nil, GetRangeInfo panics GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo // Min returns the iterator pointing at the minimum element // in the store. 
If the store is empty, it returns nil @@ -161,10 +184,17 @@ type ItemStore interface { New() any } +type ProbeResult struct { + FP any + Count int + Sim float64 +} + type RangeSetReconciler struct { is ItemStore maxSendRange int itemChunkSize int + sampleSize int } func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { @@ -172,6 +202,7 @@ func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { is: is, maxSendRange: defaultMaxSendRange, itemChunkSize: defaultItemChunkSize, + sampleSize: defaultSampleSize, } for _, opt := range opts { opt(rsr) @@ -237,12 +268,18 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg x := msg.X() y := msg.Y() done = true - if msg.Type() == MessageTypeEmptySet || (msg.Type() == MessageTypeQuery && x == nil && y == nil) { + if msg.Type() == MessageTypeEmptySet || (msg.Type() == MessageTypeProbe && x == nil && y == nil) { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) it := rsr.is.Min() if it == nil { // We don't have any items at all, too + if msg.Type() == MessageTypeProbe { + info := rsr.is.GetRangeInfo(preceding, nil, nil, -1) + if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, 0, it); err != nil { + return nil, false, err + } + } return nil, true, nil } x = it.Key() @@ -251,7 +288,7 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg return nil, false, errors.New("bad X or Y") } info := rsr.is.GetRangeInfo(preceding, x, y, -1) - // fmt.Fprintf(os.Stderr, "msg %s fp %v start %#v end %#v count %d\n", msg, info.Fingerprint, info.Start, info.End, info.Count) + // fmt.Fprintf(os.Stderr, "QQQQQ msg %s %#v fp %v start %#v end %#v count %d\n", msg.Type(), msg, info.Fingerprint, info.Start, info.End, info.Count) switch { case msg.Type() == MessageTypeEmptyRange || msg.Type() == MessageTypeRangeContents || @@ -267,8 +304,23 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding 
Iterator, msg return nil, false, err } } - case msg.Type() == MessageTypeQuery: - if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { + case msg.Type() == MessageTypeProbe: + sampleSize := msg.Count() + if sampleSize > maxSampleSize { + return nil, false, fmt.Errorf("bad minhash sample size %d (max %d)", + msg.Count(), maxSampleSize) + } else if sampleSize > info.Count { + sampleSize = info.Count + } + it := info.Start + if fingerprintEqual(msg.Fingerprint(), info.Fingerprint) { + // no need to send MinHash items if fingerprints match + it = nil + sampleSize = 0 + // fmt.Fprintf(os.Stderr, "QQQQQ: fingerprint eq %#v %#v\n", + // msg.Fingerprint(), info.Fingerprint) + } + if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, sampleSize, it); err != nil { return nil, false, err } return nil, true, nil @@ -381,47 +433,108 @@ func (rsr *RangeSetReconciler) getMessages(c Conduit) (msgs []SyncMessage, done } } -func (rsr *RangeSetReconciler) InitiateProbe(c Conduit) error { +func (rsr *RangeSetReconciler) InitiateProbe(c Conduit) (RangeInfo, error) { return rsr.InitiateBoundedProbe(c, nil, nil) } -func (rsr *RangeSetReconciler) InitiateBoundedProbe(c Conduit, x, y Ordered) error { - if err := c.SendQuery(x, y); err != nil { - return err +func (rsr *RangeSetReconciler) InitiateBoundedProbe(c Conduit, x, y Ordered) (RangeInfo, error) { + info := rsr.is.GetRangeInfo(nil, x, y, -1) + // fmt.Fprintf(os.Stderr, "QQQQQ: x %#v y %#v count %d\n", x, y, info.Count) + if err := c.SendProbe(x, y, info.Fingerprint, rsr.sampleSize); err != nil { + return RangeInfo{}, err } if err := c.SendEndRound(); err != nil { - return err + return RangeInfo{}, err } - return nil + return info, nil +} + +func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample []Ordered, fp any) float64 { + if fingerprintEqual(info.Fingerprint, fp) { + return 1 + } + if info.Start == nil { + return 0 + } + sampleSize := min(info.Count, rsr.sampleSize) + 
localSample := make([]Ordered, sampleSize) + it := info.Start + for n := 0; n < sampleSize; n++ { + // fmt.Fprintf(os.Stderr, "QQQQQ: n %d sampleSize %d info.Count %d rsr.sampleSize %d %#v\n", + // n, sampleSize, info.Count, rsr.sampleSize, it.Key()) + if it.Key() == nil { + panic("BUG: no key") + } + localSample[n] = c.ShortenKey(it.Key()) + it.Next() + } + slices.SortFunc(remoteSample, func(a, b Ordered) int { return a.Compare(b) }) + slices.SortFunc(localSample, func(a, b Ordered) int { return a.Compare(b) }) + + numEq := 0 + for m, n := 0, 0; m < len(localSample) && n < len(remoteSample); { + d := localSample[m].Compare(remoteSample[n]) + switch { + case d < 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: less: %v < %s\n", c.ShortenKey(it.Key()), remoteSample[n]) + m++ + case d == 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: eq: %v\n", remoteSample[n]) + numEq++ + m++ + n++ + default: + // fmt.Fprintf(os.Stderr, "QQQQQ: gt: %v > %s\n", c.ShortenKey(it.Key()), remoteSample[n]) + n++ + } + } + maxSampleSize := max(sampleSize, len(remoteSample)) + // fmt.Fprintf(os.Stderr, "QQQQQ: numEq %d maxSampleSize %d\n", numEq, maxSampleSize) + return float64(numEq) / float64(maxSampleSize) } -func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit) (fp any, count int, err error) { +func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (pr ProbeResult, err error) { + // fmt.Fprintf(os.Stderr, "QQQQQ: HandleProbeResponse\n") + // defer fmt.Fprintf(os.Stderr, "QQQQQ: HandleProbeResponse done\n") gotRange := false for { msg, err := c.NextMessage() switch { case err != nil: - return nil, 0, err + return ProbeResult{}, err case msg == nil: - return nil, 0, errors.New("no end round marker") + // fmt.Fprintf(os.Stderr, "QQQQQ: HandleProbeResponse: %s %#v\n", msg.Type(), msg) + return ProbeResult{}, errors.New("no end round marker") default: + // fmt.Fprintf(os.Stderr, "QQQQQ: HandleProbeResponse: %s %#v\n", msg.Type(), msg) switch mt := msg.Type(); mt { case 
MessageTypeEndRound: - return nil, 0, errors.New("non-final round in response to a probe") + return ProbeResult{}, errors.New("non-final round in response to a probe") case MessageTypeDone: // the peer is not expecting any new messages - return fp, count, nil - case MessageTypeFingerprint: - fp = msg.Fingerprint() - count = msg.Count() - fallthrough + if !gotRange { + return ProbeResult{}, errors.New("no range info received during probe") + } + return pr, nil + case MessageTypeProbeResponse: + if gotRange { + return ProbeResult{}, errors.New("single range message expected") + } + pr.FP = msg.Fingerprint() + pr.Count = msg.Count() + pr.Sim = rsr.calcSim(c, info, msg.Keys(), msg.Fingerprint()) + gotRange = true case MessageTypeEmptySet, MessageTypeEmptyRange: if gotRange { - return nil, 0, errors.New("single range message expected") + return ProbeResult{}, errors.New("single range message expected") + } + if info.Count == 0 { + pr.Sim = 1 } gotRange = true default: - return nil, 0, fmt.Errorf("unexpected message type: %v", msg.Type()) + return ProbeResult{}, fmt.Errorf( + "probe response: unexpected message type: %v", msg.Type()) } } } diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 373759ff7a..d9c7cac3b9 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -64,14 +64,7 @@ func (fc *fakeConduit) NextMessage() (SyncMessage, error) { return nil, nil } -func (fc *fakeConduit) sendMsg(mtype MessageType, x, y Ordered, fingerprint any, count int) { - msg := rangeMessage{ - mtype: mtype, - x: x, - y: y, - fp: fingerprint, - count: count, - } +func (fc *fakeConduit) sendMsg(msg rangeMessage) { fc.resp = append(fc.resp, msg) } @@ -80,26 +73,41 @@ func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) require.NotNil(fc.t, y) require.NotZero(fc.t, count) require.NotNil(fc.t, fingerprint) - fc.sendMsg(MessageTypeFingerprint, x, y, fingerprint, count) + fc.sendMsg(rangeMessage{ + mtype: 
MessageTypeFingerprint, + x: x, + y: y, + fp: fingerprint, + count: count, + }) return nil } func (fc *fakeConduit) SendEmptySet() error { - fc.sendMsg(MessageTypeEmptySet, nil, nil, nil, 0) + fc.sendMsg(rangeMessage{mtype: MessageTypeEmptySet}) return nil } func (fc *fakeConduit) SendEmptyRange(x, y Ordered) error { require.NotNil(fc.t, x) require.NotNil(fc.t, y) - fc.sendMsg(MessageTypeEmptyRange, x, y, nil, 0) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeEmptyRange, + x: x, + y: y, + }) return nil } func (fc *fakeConduit) SendRangeContents(x, y Ordered, count int) error { require.NotNil(fc.t, x) require.NotNil(fc.t, y) - fc.sendMsg(MessageTypeRangeContents, x, y, nil, count) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeRangeContents, + x: x, + y: y, + count: count, + }) return nil } @@ -119,26 +127,54 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { it.Next() n-- } - fc.resp = append(fc.resp, msg) + fc.sendMsg(msg) } return nil } func (fc *fakeConduit) SendEndRound() error { - fc.sendMsg(MessageTypeEndRound, nil, nil, nil, 0) + fc.sendMsg(rangeMessage{mtype: MessageTypeEndRound}) return nil } func (fc *fakeConduit) SendDone() error { - fc.sendMsg(MessageTypeDone, nil, nil, nil, 0) + fc.sendMsg(rangeMessage{mtype: MessageTypeDone}) + return nil +} + +func (fc *fakeConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) error { + fc.sendMsg(rangeMessage{ + mtype: MessageTypeProbe, + x: x, + y: y, + fp: fingerprint, + count: sampleSize, + }) return nil } -func (fc *fakeConduit) SendQuery(x, y Ordered) error { - fc.sendMsg(MessageTypeQuery, x, y, nil, 0) +func (fc *fakeConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { + msg := rangeMessage{ + mtype: MessageTypeProbeResponse, + x: x, + y: y, + fp: fingerprint, + count: count, + keys: make([]Ordered, sampleSize), + } + for n := 0; n < sampleSize; n++ { + require.NotNil(fc.t, it.Key()) + msg.keys[n] = it.Key() + 
it.Next() + } + fc.sendMsg(msg) return nil } +func (fc *fakeConduit) ShortenKey(k Ordered) Ordered { + return k +} + type dumbStoreIterator struct { ds *dumbStore n int @@ -226,6 +262,19 @@ func (ds *dumbStore) iterFor(s sampleID) Iterator { } func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { + if x == nil && y == nil { + it := ds.Min() + if it == nil { + return RangeInfo{ + Fingerprint: "", + } + } else { + x = it.Key() + y = x + } + } else if x == nil || y == nil { + panic("BUG: bad X or Y") + } all := storeItemStr(ds) vx := x.(sampleID) vy := y.(sampleID) @@ -557,44 +606,44 @@ func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) return i + 1, nMsg, nItems } -func runProbe(t *testing.T, from, to *RangeSetReconciler) (fp any, count int) { +func runProbe(t *testing.T, from, to *RangeSetReconciler) ProbeResult { fc := &fakeConduit{t: t} - require.NoError(t, from.InitiateProbe(fc)) - return doRunProbe(fc, from, to) + info, err := from.InitiateProbe(fc) + require.NoError(t, err) + return doRunProbe(fc, from, to, info) } -func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y Ordered) (fp any, count int) { +func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y Ordered) ProbeResult { fc := &fakeConduit{t: t} - require.NoError(t, from.InitiateBoundedProbe(fc, x, y)) - return doRunProbe(fc, from, to) + info, err := from.InitiateBoundedProbe(fc, x, y) + require.NoError(t, err) + return doRunProbe(fc, from, to, info) } -func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler) (fp any, count int) { +func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler, info RangeInfo) ProbeResult { require.NotEmpty(fc.t, fc.resp, "empty initial round") fc.gotoResponse() done, err := to.Process(fc) require.True(fc.t, done) require.NoError(fc.t, err) fc.gotoResponse() - fp, count, err = from.HandleProbeResponse(fc) + pr, err := from.HandleProbeResponse(fc, info) 
require.NoError(fc.t, err) require.Nil(fc.t, fc.resp, "got messages from Probe in response to done msg") - return fp, count + return pr } func TestRangeSync(t *testing.T) { forTestStores(t, func(t *testing.T, storeFactory storeFactory) { for _, tc := range []struct { - name string - a, b string - finalA string - finalB string - x, y string - countA int - countB int - fpA any - fpB any - maxRounds [4]int + name string + a, b string + finalA, finalB string + x, y string + countA, countB int + fpA, fpB string + maxRounds [4]int + sim float64 }{ { name: "empty sets", @@ -604,9 +653,10 @@ func TestRangeSync(t *testing.T) { finalB: "", countA: 0, countB: 0, - fpA: nil, - fpB: nil, + fpA: "", + fpB: "", maxRounds: [4]int{1, 1, 1, 1}, + sim: 1, }, { name: "empty to non-empty", @@ -616,9 +666,10 @@ func TestRangeSync(t *testing.T) { finalB: "abcd", countA: 0, countB: 4, - fpA: nil, + fpA: "", fpB: "abcd", maxRounds: [4]int{2, 2, 2, 2}, + sim: 0, }, { name: "non-empty to empty", @@ -629,8 +680,9 @@ func TestRangeSync(t *testing.T) { countA: 4, countB: 0, fpA: "abcd", - fpB: nil, + fpB: "", maxRounds: [4]int{2, 2, 2, 2}, + sim: 0, }, { name: "non-intersecting sets", @@ -643,6 +695,7 @@ func TestRangeSync(t *testing.T) { fpA: "ab", fpB: "cd", maxRounds: [4]int{3, 2, 2, 2}, + sim: 0, }, { name: "intersecting sets", @@ -655,6 +708,7 @@ func TestRangeSync(t *testing.T) { fpA: "acdefghijklmn", fpB: "bcdopqr", maxRounds: [4]int{4, 4, 3, 3}, + sim: 0.153, }, { name: "bounded reconciliation", @@ -669,6 +723,7 @@ func TestRangeSync(t *testing.T) { fpA: "acdefg", fpB: "bcd", maxRounds: [4]int{3, 3, 2, 2}, + sim: 0.333, }, { name: "bounded reconciliation with rollover", @@ -683,6 +738,7 @@ func TestRangeSync(t *testing.T) { fpA: "hijklmn", fpB: "opqr", maxRounds: [4]int{4, 3, 3, 2}, + sim: 0, }, { name: "sync against 1-element set", @@ -695,6 +751,7 @@ func TestRangeSync(t *testing.T) { fpA: "bcd", fpB: "a", maxRounds: [4]int{2, 2, 2, 2}, + sim: 0, }, } { t.Run(tc.name, func(t 
*testing.T) { @@ -712,28 +769,30 @@ func TestRangeSync(t *testing.T) { WithItemChunkSize(3)) var ( - countA, countB, nRounds int - fpA, fpB any + nRounds int + prBA, prAB ProbeResult ) if tc.x == "" { - fpA, countA = runProbe(t, syncB, syncA) - fpB, countB = runProbe(t, syncA, syncB) + prBA = runProbe(t, syncB, syncA) + prAB = runProbe(t, syncA, syncB) nRounds, _, _ = runSync(t, syncA, syncB, tc.maxRounds[n]) } else { x := sampleID(tc.x) y := sampleID(tc.y) - fpA, countA = runBoundedProbe(t, syncB, syncA, x, y) - fpB, countB = runBoundedProbe(t, syncA, syncB, x, y) + prBA = runBoundedProbe(t, syncB, syncA, x, y) + prAB = runBoundedProbe(t, syncA, syncB, x, y) nRounds, _, _ = runBoundedSync(t, syncA, syncB, x, y, tc.maxRounds[n]) } t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) - require.Equal(t, tc.countA, countA, "countA") - require.Equal(t, tc.countB, countB, "countB") - require.Equal(t, tc.fpA, fpA, "fpA") - require.Equal(t, tc.fpB, fpB, "fpB") + require.Equal(t, tc.countA, prBA.Count, "countA") + require.Equal(t, tc.countB, prAB.Count, "countB") + require.Equal(t, tc.fpA, prBA.FP, "fpA") + require.Equal(t, tc.fpB, prAB.FP, "fpB") require.Equal(t, tc.finalA, storeItemStr(storeA), "finalA") require.Equal(t, tc.finalB, storeItemStr(storeB), "finalB") + require.InDelta(t, tc.sim, prAB.Sim, 0.01, "prAB.Sim") + require.InDelta(t, tc.sim, prBA.Sim, 0.01, "prBA.Sim") } }) } diff --git a/hashsync/sync_tree_store.go b/hashsync/sync_tree_store.go index a6f7d2e2e1..048b1297a5 100644 --- a/hashsync/sync_tree_store.go +++ b/hashsync/sync_tree_store.go @@ -50,6 +50,7 @@ type SyncTreeStore struct { st SyncTree vh ValueHandler newValue NewValueFunc + identity any } var _ ItemStore = &SyncTreeStore{} @@ -62,6 +63,7 @@ func NewSyncTreeStore(m Monoid, vh ValueHandler, newValue NewValueFunc) ItemStor st: NewSyncTree(CombineMonoids(m, CountingMonoid{})), vh: vh, newValue: newValue, + identity: m.Identity(), } } @@ -83,7 +85,20 @@ func (sts *SyncTreeStore) 
iter(ptr SyncTreePointer) Iterator { } // GetRangeInfo implements ItemStore. -func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x Ordered, y Ordered, count int) RangeInfo { +func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { + if x == nil && y == nil { + it := sts.Min() + if it == nil { + return RangeInfo{ + Fingerprint: sts.identity, + } + } else { + x = it.Key() + y = x + } + } else if x == nil || y == nil { + panic("BUG: bad X or Y") + } var stop FingerprintPredicate var node SyncTreePointer if preceding != nil { diff --git a/hashsync/wire_types.go b/hashsync/wire_types.go index 839525bb48..27cc1ba49e 100644 --- a/hashsync/wire_types.go +++ b/hashsync/wire_types.go @@ -1,6 +1,10 @@ package hashsync import ( + "cmp" + "fmt" + + "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" ) @@ -101,29 +105,103 @@ type ItemBatchMessage struct { func (m *ItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } -// QueryMessage requests bounded range fingerprint and count from the peer -type QueryMessage struct { - RangeX, RangeY *types.Hash32 +// ProbeMessage requests bounded range fingerprint and count from the peer, +// along with a minhash sample if fingerprints differ +type ProbeMessage struct { + RangeX, RangeY *types.Hash32 + RangeFingerprint types.Hash12 + SampleSize uint32 +} + +var _ SyncMessage = &ProbeMessage{} + +func (m *ProbeMessage) Type() MessageType { return MessageTypeProbe } +func (m *ProbeMessage) X() Ordered { + if m.RangeX == nil { + return nil + } + return *m.RangeX +} + +func (m *ProbeMessage) Y() Ordered { + if m.RangeY == nil { + return nil + } + return *m.RangeY +} +func (m *ProbeMessage) Fingerprint() any { return m.RangeFingerprint } +func (m *ProbeMessage) Count() int { return int(m.SampleSize) } +func (m *ProbeMessage) Keys() []Ordered { return nil } +func (m *ProbeMessage) Values() []any { return nil } + +// MinhashSampleItem represents an item 
of minhash sample subset +type MinhashSampleItem uint32 + +var _ Ordered = MinhashSampleItem(0) + +func (m MinhashSampleItem) String() string { + return fmt.Sprintf("0x%08x", uint32(m)) } -var _ SyncMessage = &QueryMessage{} +// Compare implements Ordered +func (m MinhashSampleItem) Compare(other any) int { + return cmp.Compare(m, other.(MinhashSampleItem)) +} + +// EncodeScale implements scale.Encodable. +func (m MinhashSampleItem) EncodeScale(e *scale.Encoder) (int, error) { + // QQQQQ: FIXME: there's EncodeUint32 (non-compact which is better for hashes) + // but no DecodeUint32 + return scale.EncodeCompact32(e, uint32(m)) +} + +// DecodeScale implements scale.Decodable. +func (m *MinhashSampleItem) DecodeScale(d *scale.Decoder) (int, error) { + v, total, err := scale.DecodeCompact32(d) + *m = MinhashSampleItem(v) + return total, err +} + +// MinhashSampleItemFromHash32 uses lower 32 bits of a Hash32 as a MinhashSampleItem +func MinhashSampleItemFromHash32(h types.Hash32) MinhashSampleItem { + return MinhashSampleItem(uint32(h[28])<<24 + uint32(h[29])<<16 + uint32(h[30])<<8 + uint32(h[31])) +} -func (m *QueryMessage) Type() MessageType { return MessageTypeQuery } -func (m *QueryMessage) X() Ordered { +// ProbeResponseMessage is a response to ProbeMessage +type ProbeResponseMessage struct { + RangeX, RangeY *types.Hash32 + RangeFingerprint types.Hash12 + NumItems uint32 + // NOTE: max must be in sync with maxSampleSize in hashsync/rangesync.go + Sample []MinhashSampleItem `scale:"max=1000"` +} + +var _ SyncMessage = &ProbeResponseMessage{} + +func (m *ProbeResponseMessage) Type() MessageType { return MessageTypeProbeResponse } +func (m *ProbeResponseMessage) X() Ordered { if m.RangeX == nil { return nil } return *m.RangeX } -func (m *QueryMessage) Y() Ordered { + +func (m *ProbeResponseMessage) Y() Ordered { if m.RangeY == nil { return nil } return *m.RangeY } -func (m *QueryMessage) Fingerprint() any { return nil } -func (m *QueryMessage) Count() int { return 0 } 
-func (m *QueryMessage) Keys() []Ordered { return nil } -func (m *QueryMessage) Values() []any { return nil } +func (m *ProbeResponseMessage) Fingerprint() any { return m.RangeFingerprint } +func (m *ProbeResponseMessage) Count() int { return int(m.NumItems) } +func (m *ProbeResponseMessage) Values() []any { return nil } + +func (m *ProbeResponseMessage) Keys() []Ordered { + r := make([]Ordered, len(m.Sample)) + for n, item := range m.Sample { + r[n] = item + } + return r +} // TODO: don't do scalegen for empty types diff --git a/hashsync/wire_types_scale.go b/hashsync/wire_types_scale.go index fddb0ae1ae..164c4f1cf5 100644 --- a/hashsync/wire_types_scale.go +++ b/hashsync/wire_types_scale.go @@ -272,7 +272,7 @@ func (t *ItemBatchMessage) DecodeScale(dec *scale.Decoder) (total int, err error return total, nil } -func (t *QueryMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { +func (t *ProbeMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { n, err := scale.EncodeOption(enc, t.RangeX) if err != nil { @@ -287,10 +287,98 @@ func (t *QueryMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { } total += n } + { + n, err := scale.EncodeByteArray(enc, t.RangeFingerprint[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.SampleSize)) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *ProbeMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + field, n, err := scale.DecodeOption[types.Hash32](dec) + if err != nil { + return total, err + } + total += n + t.RangeX = field + } + { + field, n, err := scale.DecodeOption[types.Hash32](dec) + if err != nil { + return total, err + } + total += n + t.RangeY = field + } + { + n, err := scale.DecodeByteArray(dec, t.RangeFingerprint[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return 
total, err + } + total += n + t.SampleSize = uint32(field) + } + return total, nil +} + +func (t *ProbeResponseMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeOption(enc, t.RangeX) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeOption(enc, t.RangeY) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeByteArray(enc, t.RangeFingerprint[:]) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeCompact32(enc, uint32(t.NumItems)) + if err != nil { + return total, err + } + total += n + } + { + n, err := scale.EncodeStructSliceWithLimit(enc, t.Sample, 1000) + if err != nil { + return total, err + } + total += n + } return total, nil } -func (t *QueryMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { +func (t *ProbeResponseMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { { field, n, err := scale.DecodeOption[types.Hash32](dec) if err != nil { @@ -307,5 +395,28 @@ func (t *QueryMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { total += n t.RangeY = field } + { + n, err := scale.DecodeByteArray(dec, t.RangeFingerprint[:]) + if err != nil { + return total, err + } + total += n + } + { + field, n, err := scale.DecodeCompact32(dec) + if err != nil { + return total, err + } + total += n + t.NumItems = uint32(field) + } + { + field, n, err := scale.DecodeStructSliceWithLimit[MinhashSampleItem](dec, 1000) + if err != nil { + return total, err + } + total += n + t.Sample = field + } return total, nil } From f05fb38bae048a3fcf7e2a892403fd56dc3e9add Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 15 May 2024 16:31:25 +0400 Subject: [PATCH 18/76] hashsync: implemented part of multipeer sync --- fetch/peers/peers.go | 7 + hashsync/handler.go | 68 ++--- hashsync/handler_test.go | 28 +- hashsync/interface.go | 17 +- hashsync/mocks_test.go | 320 +++++++++++------------ 
hashsync/multipeer.go | 499 +++++++++++++++++------------------- hashsync/rangesync.go | 19 +- hashsync/rangesync_test.go | 56 +++- hashsync/setsyncbase.go | 121 +++++++++ hashsync/split_sync.go | 47 ++-- hashsync/split_sync_test.go | 50 ++-- hashsync/sync_tree.go | 30 +-- hashsync/sync_tree_store.go | 21 +- hashsync/sync_tree_test.go | 6 +- hashsync/xorsync_test.go | 18 +- 15 files changed, 723 insertions(+), 584 deletions(-) create mode 100644 hashsync/setsyncbase.go diff --git a/fetch/peers/peers.go b/fetch/peers/peers.go index 00b0f3ce26..ed736bf2f2 100644 --- a/fetch/peers/peers.go +++ b/fetch/peers/peers.go @@ -54,6 +54,13 @@ type Peers struct { globalLatency float64 } +func (p *Peers) Contains(id peer.ID) bool { + p.mu.Lock() + defer p.mu.Unlock() + _, exist := p.peers[id] + return exist +} + func (p *Peers) Add(id peer.ID) bool { p.mu.Lock() defer p.mu.Unlock() diff --git a/hashsync/handler.go b/hashsync/handler.go index 055896c056..63090f32ff 100644 --- a/hashsync/handler.go +++ b/hashsync/handler.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "sync" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" @@ -77,6 +78,35 @@ func decodeItemBatchMessage(m *ItemBatchMessage, newValue NewValueFunc) (*decode return d, nil } +// QQQQQ: rmme +var ( + numRead int + numWritten int + smtx sync.Mutex +) + +type rmmeCountingStream struct { + io.ReadWriter +} + +// Read implements io.ReadWriter. +func (r *rmmeCountingStream) Read(p []byte) (n int, err error) { + smtx.Lock() + defer smtx.Unlock() + n, err = r.ReadWriter.Read(p) + numRead += n + return n, err +} + +// Write implements io.ReadWriter. 
+func (r *rmmeCountingStream) Write(p []byte) (n int, err error) { + smtx.Lock() + defer smtx.Unlock() + n, err = r.ReadWriter.Write(p) + numWritten += n + return n, err +} + type conduitState int type wireConduit struct { @@ -277,11 +307,11 @@ func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, er return c.initReqBuf.Bytes(), nil } -func (c *wireConduit) handleStream(stream io.ReadWriter, rsr *RangeSetReconciler) error { +func (c *wireConduit) handleStream(ctx context.Context, stream io.ReadWriter, rsr *RangeSetReconciler) error { c.stream = stream for { // Process() will receive all items and messages from the peer - syncDone, err := rsr.Process(c) + syncDone, err := rsr.Process(ctx, c) if err != nil { return err } else if syncDone { @@ -307,19 +337,11 @@ func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { Reader: io.MultiReader(bytes.NewBuffer(req), stream), Writer: stream, } - return c.handleStream(s, rsr) + return c.handleStream(ctx, s, rsr) } } -func BoundedSyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y types.Hash32, opts ...Option) error { - return syncStore(ctx, r, peer, is, &x, &y, opts) -} - -func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, opts ...Option) error { - return syncStore(ctx, r, peer, is, nil, nil, opts) -} - -func syncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, opts []Option) error { +func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, opts ...Option) error { c := wireConduit{newValue: is.New} rsr := NewRangeSetReconciler(is, opts...) 
// c.rmmePrint = true @@ -338,32 +360,18 @@ func syncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, return err } return r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { - return c.handleStream(stream, rsr) + s := &rmmeCountingStream{ReadWriter: stream} + return c.handleStream(ctx, s, rsr) }) } -func Probe(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, opts ...Option) (ProbeResult, error) { - return boundedProbe(ctx, r, peer, is, nil, nil, opts) -} - -func BoundedProbe( - ctx context.Context, - r requester, - peer p2p.Peer, - is ItemStore, - x, y types.Hash32, - opts ...Option, -) (ProbeResult, error) { - return boundedProbe(ctx, r, peer, is, &x, &y, opts) -} - -func boundedProbe( +func Probe( ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, - opts []Option, + opts ...Option, ) (ProbeResult, error) { var ( err error diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index 25557c92ee..bdc84d5f02 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -418,6 +418,13 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { cfg := xorSyncTestConfig{ + // large test: + // maxSendRange: 1, + // numTestHashes: 5000000, + // minNumSpecificA: 15000, + // maxNumSpecificA: 20000, + // minNumSpecificB: 15000, + // maxNumSpecificB: 20000, maxSendRange: 1, numTestHashes: 100000, minNumSpecificA: 4, @@ -430,13 +437,16 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { withClientServer( storeA, getRequester, opts, func(ctx context.Context, client requester, srvPeerID p2p.Peer) { - err := SyncStore(ctx, client, srvPeerID, storeB, opts...) + err := SyncStore(ctx, client, srvPeerID, storeB, nil, nil, opts...) 
require.NoError(t, err) if fr, ok := client.(*fakeRequester); ok { t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", numSpecific, fr.bytesSent, fr.bytesReceived) } + smtx.Lock() + t.Logf("bytes read: %d, bytes written: %d", numRead, numWritten) + smtx.Unlock() }) return true }) @@ -444,9 +454,9 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { } func TestWireSync(t *testing.T) { - t.Run("fake requester", func(t *testing.T) { - testWireSync(t, fakeRequesterGetter()) - }) + // t.Run("fake requester", func(t *testing.T) { + // testWireSync(t, fakeRequesterGetter()) + // }) t.Run("p2p", func(t *testing.T) { testWireSync(t, p2pRequesterGetter(t)) }) @@ -468,7 +478,7 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { func(ctx context.Context, client requester, srvPeerID p2p.Peer) { minA := storeA.Min().Key() infoA := storeA.GetRangeInfo(nil, minA, minA, -1) - prA, err := Probe(ctx, client, srvPeerID, storeB, opts...) + prA, err := Probe(ctx, client, srvPeerID, storeB, nil, nil, opts...) require.NoError(t, err) require.Equal(t, infoA.Fingerprint, prA.FP) require.Equal(t, infoA.Count, prA.Count) @@ -479,7 +489,7 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { x := partInfoA.Start.Key().(types.Hash32) y := partInfoA.End.Key().(types.Hash32) // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) - prA, err = BoundedProbe(ctx, client, srvPeerID, storeB, x, y, opts...) + prA, err = Probe(ctx, client, srvPeerID, storeB, &x, &y, opts...) 
require.NoError(t, err) require.Equal(t, partInfoA.Fingerprint, prA.FP) require.Equal(t, partInfoA.Count, prA.Count) @@ -495,9 +505,9 @@ func TestWireProbe(t *testing.T) { t.Run("fake requester", func(t *testing.T) { testWireProbe(t, fakeRequesterGetter()) }) - // t.Run("p2p", func(t *testing.T) { - // testWireProbe(t, p2pRequesterGetter(t)) - // }) + t.Run("p2p", func(t *testing.T) { + testWireProbe(t, p2pRequesterGetter(t)) + }) } // TODO: test bounded sync diff --git a/hashsync/interface.go b/hashsync/interface.go index e57e2cfc58..042d57a86d 100644 --- a/hashsync/interface.go +++ b/hashsync/interface.go @@ -15,20 +15,19 @@ type requester interface { StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error } -type peerSet interface { - addPeer(p p2p.Peer) - removePeer(p p2p.Peer) - numPeers() int - listPeers() []p2p.Peer - havePeer(p p2p.Peer) bool -} - type syncBase interface { + count() int derive(p p2p.Peer) syncer - probe(ctx context.Context, p p2p.Peer) (int, error) + probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) + run(ctx context.Context) error } type syncer interface { peer() p2p.Peer sync(ctx context.Context, x, y *types.Hash32) error } + +type syncRunner interface { + splitSync(ctx context.Context, syncPeers []p2p.Peer) error + fullSync(ctx context.Context, syncPeers []p2p.Peer) error +} diff --git a/hashsync/mocks_test.go b/hashsync/mocks_test.go index b5495106db..7a0848ce0b 100644 --- a/hashsync/mocks_test.go +++ b/hashsync/mocks_test.go @@ -123,410 +123,376 @@ func (c *MockrequesterStreamRequestCall) DoAndReturn(f func(context.Context, p2p return c } -// MockpeerSet is a mock of peerSet interface. -type MockpeerSet struct { +// MocksyncBase is a mock of syncBase interface. +type MocksyncBase struct { ctrl *gomock.Controller - recorder *MockpeerSetMockRecorder + recorder *MocksyncBaseMockRecorder } -// MockpeerSetMockRecorder is the mock recorder for MockpeerSet. 
-type MockpeerSetMockRecorder struct { - mock *MockpeerSet +// MocksyncBaseMockRecorder is the mock recorder for MocksyncBase. +type MocksyncBaseMockRecorder struct { + mock *MocksyncBase } -// NewMockpeerSet creates a new mock instance. -func NewMockpeerSet(ctrl *gomock.Controller) *MockpeerSet { - mock := &MockpeerSet{ctrl: ctrl} - mock.recorder = &MockpeerSetMockRecorder{mock} +// NewMocksyncBase creates a new mock instance. +func NewMocksyncBase(ctrl *gomock.Controller) *MocksyncBase { + mock := &MocksyncBase{ctrl: ctrl} + mock.recorder = &MocksyncBaseMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockpeerSet) EXPECT() *MockpeerSetMockRecorder { +func (m *MocksyncBase) EXPECT() *MocksyncBaseMockRecorder { return m.recorder } -// addPeer mocks base method. -func (m *MockpeerSet) addPeer(p p2p.Peer) { +// count mocks base method. +func (m *MocksyncBase) count() int { m.ctrl.T.Helper() - m.ctrl.Call(m, "addPeer", p) -} - -// addPeer indicates an expected call of addPeer. -func (mr *MockpeerSetMockRecorder) addPeer(p any) *MockpeerSetaddPeerCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "addPeer", reflect.TypeOf((*MockpeerSet)(nil).addPeer), p) - return &MockpeerSetaddPeerCall{Call: call} -} - -// MockpeerSetaddPeerCall wrap *gomock.Call -type MockpeerSetaddPeerCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockpeerSetaddPeerCall) Return() *MockpeerSetaddPeerCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockpeerSetaddPeerCall) Do(f func(p2p.Peer)) *MockpeerSetaddPeerCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpeerSetaddPeerCall) DoAndReturn(f func(p2p.Peer)) *MockpeerSetaddPeerCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// havePeer mocks base method. 
-func (m *MockpeerSet) havePeer(p p2p.Peer) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "havePeer", p) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "count") + ret0, _ := ret[0].(int) return ret0 } -// havePeer indicates an expected call of havePeer. -func (mr *MockpeerSetMockRecorder) havePeer(p any) *MockpeerSethavePeerCall { +// count indicates an expected call of count. +func (mr *MocksyncBaseMockRecorder) count() *MocksyncBasecountCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "havePeer", reflect.TypeOf((*MockpeerSet)(nil).havePeer), p) - return &MockpeerSethavePeerCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "count", reflect.TypeOf((*MocksyncBase)(nil).count)) + return &MocksyncBasecountCall{Call: call} } -// MockpeerSethavePeerCall wrap *gomock.Call -type MockpeerSethavePeerCall struct { +// MocksyncBasecountCall wrap *gomock.Call +type MocksyncBasecountCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockpeerSethavePeerCall) Return(arg0 bool) *MockpeerSethavePeerCall { +func (c *MocksyncBasecountCall) Return(arg0 int) *MocksyncBasecountCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockpeerSethavePeerCall) Do(f func(p2p.Peer) bool) *MockpeerSethavePeerCall { +func (c *MocksyncBasecountCall) Do(f func() int) *MocksyncBasecountCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpeerSethavePeerCall) DoAndReturn(f func(p2p.Peer) bool) *MockpeerSethavePeerCall { +func (c *MocksyncBasecountCall) DoAndReturn(f func() int) *MocksyncBasecountCall { c.Call = c.Call.DoAndReturn(f) return c } -// listPeers mocks base method. -func (m *MockpeerSet) listPeers() []p2p.Peer { +// derive mocks base method. 
+func (m *MocksyncBase) derive(p p2p.Peer) syncer { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "listPeers") - ret0, _ := ret[0].([]p2p.Peer) + ret := m.ctrl.Call(m, "derive", p) + ret0, _ := ret[0].(syncer) return ret0 } -// listPeers indicates an expected call of listPeers. -func (mr *MockpeerSetMockRecorder) listPeers() *MockpeerSetlistPeersCall { +// derive indicates an expected call of derive. +func (mr *MocksyncBaseMockRecorder) derive(p any) *MocksyncBasederiveCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "listPeers", reflect.TypeOf((*MockpeerSet)(nil).listPeers)) - return &MockpeerSetlistPeersCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "derive", reflect.TypeOf((*MocksyncBase)(nil).derive), p) + return &MocksyncBasederiveCall{Call: call} } -// MockpeerSetlistPeersCall wrap *gomock.Call -type MockpeerSetlistPeersCall struct { +// MocksyncBasederiveCall wrap *gomock.Call +type MocksyncBasederiveCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockpeerSetlistPeersCall) Return(arg0 []p2p.Peer) *MockpeerSetlistPeersCall { +func (c *MocksyncBasederiveCall) Return(arg0 syncer) *MocksyncBasederiveCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockpeerSetlistPeersCall) Do(f func() []p2p.Peer) *MockpeerSetlistPeersCall { +func (c *MocksyncBasederiveCall) Do(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpeerSetlistPeersCall) DoAndReturn(f func() []p2p.Peer) *MockpeerSetlistPeersCall { +func (c *MocksyncBasederiveCall) DoAndReturn(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { c.Call = c.Call.DoAndReturn(f) return c } -// numPeers mocks base method. -func (m *MockpeerSet) numPeers() int { +// probe mocks base method. 
+func (m *MocksyncBase) probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "numPeers") - ret0, _ := ret[0].(int) - return ret0 + ret := m.ctrl.Call(m, "probe", ctx, p) + ret0, _ := ret[0].(ProbeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// numPeers indicates an expected call of numPeers. -func (mr *MockpeerSetMockRecorder) numPeers() *MockpeerSetnumPeersCall { +// probe indicates an expected call of probe. +func (mr *MocksyncBaseMockRecorder) probe(ctx, p any) *MocksyncBaseprobeCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "numPeers", reflect.TypeOf((*MockpeerSet)(nil).numPeers)) - return &MockpeerSetnumPeersCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "probe", reflect.TypeOf((*MocksyncBase)(nil).probe), ctx, p) + return &MocksyncBaseprobeCall{Call: call} } -// MockpeerSetnumPeersCall wrap *gomock.Call -type MockpeerSetnumPeersCall struct { +// MocksyncBaseprobeCall wrap *gomock.Call +type MocksyncBaseprobeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockpeerSetnumPeersCall) Return(arg0 int) *MockpeerSetnumPeersCall { - c.Call = c.Call.Return(arg0) +func (c *MocksyncBaseprobeCall) Return(arg0 ProbeResult, arg1 error) *MocksyncBaseprobeCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockpeerSetnumPeersCall) Do(f func() int) *MockpeerSetnumPeersCall { +func (c *MocksyncBaseprobeCall) Do(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MocksyncBaseprobeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpeerSetnumPeersCall) DoAndReturn(f func() int) *MockpeerSetnumPeersCall { +func (c *MocksyncBaseprobeCall) DoAndReturn(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MocksyncBaseprobeCall { c.Call = c.Call.DoAndReturn(f) return c } -// removePeer mocks base method. 
-func (m *MockpeerSet) removePeer(p p2p.Peer) { +// run mocks base method. +func (m *MocksyncBase) run(ctx context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "removePeer", p) + ret := m.ctrl.Call(m, "run", ctx) + ret0, _ := ret[0].(error) + return ret0 } -// removePeer indicates an expected call of removePeer. -func (mr *MockpeerSetMockRecorder) removePeer(p any) *MockpeerSetremovePeerCall { +// run indicates an expected call of run. +func (mr *MocksyncBaseMockRecorder) run(ctx any) *MocksyncBaserunCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removePeer", reflect.TypeOf((*MockpeerSet)(nil).removePeer), p) - return &MockpeerSetremovePeerCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MocksyncBase)(nil).run), ctx) + return &MocksyncBaserunCall{Call: call} } -// MockpeerSetremovePeerCall wrap *gomock.Call -type MockpeerSetremovePeerCall struct { +// MocksyncBaserunCall wrap *gomock.Call +type MocksyncBaserunCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockpeerSetremovePeerCall) Return() *MockpeerSetremovePeerCall { - c.Call = c.Call.Return() +func (c *MocksyncBaserunCall) Return(arg0 error) *MocksyncBaserunCall { + c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockpeerSetremovePeerCall) Do(f func(p2p.Peer)) *MockpeerSetremovePeerCall { +func (c *MocksyncBaserunCall) Do(f func(context.Context) error) *MocksyncBaserunCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpeerSetremovePeerCall) DoAndReturn(f func(p2p.Peer)) *MockpeerSetremovePeerCall { +func (c *MocksyncBaserunCall) DoAndReturn(f func(context.Context) error) *MocksyncBaserunCall { c.Call = c.Call.DoAndReturn(f) return c } -// MocksyncBase is a mock of syncBase interface. -type MocksyncBase struct { +// Mocksyncer is a mock of syncer interface. 
+type Mocksyncer struct { ctrl *gomock.Controller - recorder *MocksyncBaseMockRecorder + recorder *MocksyncerMockRecorder } -// MocksyncBaseMockRecorder is the mock recorder for MocksyncBase. -type MocksyncBaseMockRecorder struct { - mock *MocksyncBase +// MocksyncerMockRecorder is the mock recorder for Mocksyncer. +type MocksyncerMockRecorder struct { + mock *Mocksyncer } -// NewMocksyncBase creates a new mock instance. -func NewMocksyncBase(ctrl *gomock.Controller) *MocksyncBase { - mock := &MocksyncBase{ctrl: ctrl} - mock.recorder = &MocksyncBaseMockRecorder{mock} +// NewMocksyncer creates a new mock instance. +func NewMocksyncer(ctrl *gomock.Controller) *Mocksyncer { + mock := &Mocksyncer{ctrl: ctrl} + mock.recorder = &MocksyncerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MocksyncBase) EXPECT() *MocksyncBaseMockRecorder { +func (m *Mocksyncer) EXPECT() *MocksyncerMockRecorder { return m.recorder } -// derive mocks base method. -func (m *MocksyncBase) derive(p p2p.Peer) syncer { +// peer mocks base method. +func (m *Mocksyncer) peer() p2p.Peer { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "derive", p) - ret0, _ := ret[0].(syncer) + ret := m.ctrl.Call(m, "peer") + ret0, _ := ret[0].(p2p.Peer) return ret0 } -// derive indicates an expected call of derive. -func (mr *MocksyncBaseMockRecorder) derive(p any) *MocksyncBasederiveCall { +// peer indicates an expected call of peer. 
+func (mr *MocksyncerMockRecorder) peer() *MocksyncerpeerCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "derive", reflect.TypeOf((*MocksyncBase)(nil).derive), p) - return &MocksyncBasederiveCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "peer", reflect.TypeOf((*Mocksyncer)(nil).peer)) + return &MocksyncerpeerCall{Call: call} } -// MocksyncBasederiveCall wrap *gomock.Call -type MocksyncBasederiveCall struct { +// MocksyncerpeerCall wrap *gomock.Call +type MocksyncerpeerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBasederiveCall) Return(arg0 syncer) *MocksyncBasederiveCall { +func (c *MocksyncerpeerCall) Return(arg0 p2p.Peer) *MocksyncerpeerCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBasederiveCall) Do(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { +func (c *MocksyncerpeerCall) Do(f func() p2p.Peer) *MocksyncerpeerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncBasederiveCall) DoAndReturn(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { +func (c *MocksyncerpeerCall) DoAndReturn(f func() p2p.Peer) *MocksyncerpeerCall { c.Call = c.Call.DoAndReturn(f) return c } -// probe mocks base method. -func (m *MocksyncBase) probe(ctx context.Context, p p2p.Peer) (int, error) { +// sync mocks base method. +func (m *Mocksyncer) sync(ctx context.Context, x, y *types.Hash32) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "probe", ctx, p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "sync", ctx, x, y) + ret0, _ := ret[0].(error) + return ret0 } -// probe indicates an expected call of probe. -func (mr *MocksyncBaseMockRecorder) probe(ctx, p any) *MocksyncBaseprobeCall { +// sync indicates an expected call of sync. 
+func (mr *MocksyncerMockRecorder) sync(ctx, x, y any) *MocksyncersyncCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "probe", reflect.TypeOf((*MocksyncBase)(nil).probe), ctx, p) - return &MocksyncBaseprobeCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "sync", reflect.TypeOf((*Mocksyncer)(nil).sync), ctx, x, y) + return &MocksyncersyncCall{Call: call} } -// MocksyncBaseprobeCall wrap *gomock.Call -type MocksyncBaseprobeCall struct { +// MocksyncersyncCall wrap *gomock.Call +type MocksyncersyncCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBaseprobeCall) Return(arg0 int, arg1 error) *MocksyncBaseprobeCall { - c.Call = c.Call.Return(arg0, arg1) +func (c *MocksyncersyncCall) Return(arg0 error) *MocksyncersyncCall { + c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBaseprobeCall) Do(f func(context.Context, p2p.Peer) (int, error)) *MocksyncBaseprobeCall { +func (c *MocksyncersyncCall) Do(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncBaseprobeCall) DoAndReturn(f func(context.Context, p2p.Peer) (int, error)) *MocksyncBaseprobeCall { +func (c *MocksyncersyncCall) DoAndReturn(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { c.Call = c.Call.DoAndReturn(f) return c } -// Mocksyncer is a mock of syncer interface. -type Mocksyncer struct { +// MocksyncRunner is a mock of syncRunner interface. +type MocksyncRunner struct { ctrl *gomock.Controller - recorder *MocksyncerMockRecorder + recorder *MocksyncRunnerMockRecorder } -// MocksyncerMockRecorder is the mock recorder for Mocksyncer. -type MocksyncerMockRecorder struct { - mock *Mocksyncer +// MocksyncRunnerMockRecorder is the mock recorder for MocksyncRunner. 
+type MocksyncRunnerMockRecorder struct { + mock *MocksyncRunner } -// NewMocksyncer creates a new mock instance. -func NewMocksyncer(ctrl *gomock.Controller) *Mocksyncer { - mock := &Mocksyncer{ctrl: ctrl} - mock.recorder = &MocksyncerMockRecorder{mock} +// NewMocksyncRunner creates a new mock instance. +func NewMocksyncRunner(ctrl *gomock.Controller) *MocksyncRunner { + mock := &MocksyncRunner{ctrl: ctrl} + mock.recorder = &MocksyncRunnerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mocksyncer) EXPECT() *MocksyncerMockRecorder { +func (m *MocksyncRunner) EXPECT() *MocksyncRunnerMockRecorder { return m.recorder } -// peer mocks base method. -func (m *Mocksyncer) peer() p2p.Peer { +// fullSync mocks base method. +func (m *MocksyncRunner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "peer") - ret0, _ := ret[0].(p2p.Peer) + ret := m.ctrl.Call(m, "fullSync", ctx, syncPeers) + ret0, _ := ret[0].(error) return ret0 } -// peer indicates an expected call of peer. -func (mr *MocksyncerMockRecorder) peer() *MocksyncerpeerCall { +// fullSync indicates an expected call of fullSync. 
+func (mr *MocksyncRunnerMockRecorder) fullSync(ctx, syncPeers any) *MocksyncRunnerfullSyncCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "peer", reflect.TypeOf((*Mocksyncer)(nil).peer)) - return &MocksyncerpeerCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "fullSync", reflect.TypeOf((*MocksyncRunner)(nil).fullSync), ctx, syncPeers) + return &MocksyncRunnerfullSyncCall{Call: call} } -// MocksyncerpeerCall wrap *gomock.Call -type MocksyncerpeerCall struct { +// MocksyncRunnerfullSyncCall wrap *gomock.Call +type MocksyncRunnerfullSyncCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncerpeerCall) Return(arg0 p2p.Peer) *MocksyncerpeerCall { +func (c *MocksyncRunnerfullSyncCall) Return(arg0 error) *MocksyncRunnerfullSyncCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncerpeerCall) Do(f func() p2p.Peer) *MocksyncerpeerCall { +func (c *MocksyncRunnerfullSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncerpeerCall) DoAndReturn(f func() p2p.Peer) *MocksyncerpeerCall { +func (c *MocksyncRunnerfullSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { c.Call = c.Call.DoAndReturn(f) return c } -// sync mocks base method. -func (m *Mocksyncer) sync(ctx context.Context, x, y *types.Hash32) error { +// splitSync mocks base method. +func (m *MocksyncRunner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "sync", ctx, x, y) + ret := m.ctrl.Call(m, "splitSync", ctx, syncPeers) ret0, _ := ret[0].(error) return ret0 } -// sync indicates an expected call of sync. -func (mr *MocksyncerMockRecorder) sync(ctx, x, y any) *MocksyncersyncCall { +// splitSync indicates an expected call of splitSync. 
+func (mr *MocksyncRunnerMockRecorder) splitSync(ctx, syncPeers any) *MocksyncRunnersplitSyncCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "sync", reflect.TypeOf((*Mocksyncer)(nil).sync), ctx, x, y) - return &MocksyncersyncCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "splitSync", reflect.TypeOf((*MocksyncRunner)(nil).splitSync), ctx, syncPeers) + return &MocksyncRunnersplitSyncCall{Call: call} } -// MocksyncersyncCall wrap *gomock.Call -type MocksyncersyncCall struct { +// MocksyncRunnersplitSyncCall wrap *gomock.Call +type MocksyncRunnersplitSyncCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncersyncCall) Return(arg0 error) *MocksyncersyncCall { +func (c *MocksyncRunnersplitSyncCall) Return(arg0 error) *MocksyncRunnersplitSyncCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncersyncCall) Do(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { +func (c *MocksyncRunnersplitSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncersyncCall) DoAndReturn(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { +func (c *MocksyncRunnersplitSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/hashsync/multipeer.go b/hashsync/multipeer.go index ebdbd286dd..e3d29cfb07 100644 --- a/hashsync/multipeer.go +++ b/hashsync/multipeer.go @@ -3,77 +3,69 @@ package hashsync import ( "context" "errors" - "sync" "time" "github.com/jonboulle/clockwork" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/p2p" - "go.uber.org/zap" - 
"golang.org/x/exp/maps" ) -// type DataHandler func(context.Context, types.Hash32, p2p.Peer, any) error - -// type dataItem struct { -// key types.Hash32 -// value any -// } - -// type dataItemHandler func(di dataItem) - -// type derivedStore struct { -// ItemStore -// handler dataItemHandler -// // itemCh chan dataItem -// // // TODO: don't embed context in the struct -// // ctx context.Context -// } - -// func (s *derivedStore) Add(k Ordered, v any) { -// s.ItemStore.Add(k, v) -// s.handler(dataItem{key: k.(types.Hash32), value: v}) -// // select { -// // case <-s.ctx.Done(): -// // case s.itemCh <- dataItem{key: k.(types.Hash32), value: v}: -// // } -// } - -type probeResult struct { - probed map[p2p.Peer]int - minCount int - maxCount int +type syncability struct { + // peers that were probed successfully + syncable []p2p.Peer + // peers that have enough items for split sync + splitSyncable []p2p.Peer + // Number of peers that are similar enough to this one for full sync + nearFullCount int } -// type peerReconciler struct { -// st SyncTree -// } - type MultiPeerReconcilerOpt func(mpr *MultiPeerReconciler) +func WithSyncPeerCount(count int) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.syncPeerCount = count + } +} + func WithMinFullSyncCount(count int) MultiPeerReconcilerOpt { return func(mpr *MultiPeerReconciler) { - mpr.minPartSyncCount = count + mpr.minSplitSyncCount = count } } -func WithMinFullFraction(frac float64) MultiPeerReconcilerOpt { +func WithMaxFullDiff(diff int) MultiPeerReconcilerOpt { return func(mpr *MultiPeerReconciler) { - mpr.minFullFraction = frac + mpr.maxFullDiff = diff } } -// func WithMinPartSyncPeers(n int) MultiPeerReconcilerOpt { -// return func(mpr *MultiPeerReconciler) { -// mpr.minPartSyncPeers = n -// } -// } +func WithSyncInterval(d time.Duration) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.syncInterval = d + } +} -// func WithPeerSyncTimeout(t time.Duration) 
MultiPeerReconcilerOpt { -// return func(mpr *MultiPeerReconciler) { -// mpr.peerSyncTimeout = t -// } -// } +func WithNoPeersRecheckInterval(d time.Duration) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.noPeersRecheckInterval = d + } +} + +func WithMinSplitSyncPeers(n int) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.minSplitSyncPeers = n + } +} + +func WithMinCompleteFraction(f float64) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.minCompleteFraction = f + } +} func WithSplitSyncGracePeriod(t time.Duration) MultiPeerReconcilerOpt { return func(mpr *MultiPeerReconciler) { @@ -87,245 +79,222 @@ func withClock(clock clockwork.Clock) MultiPeerReconcilerOpt { } } -type MultiPeerReconciler struct { - logger zap.Logger - // minPartSyncPeers int - minPartSyncCount int - minFullFraction float64 - splitSyncGracePeriod time.Duration - // peerSyncTimeout time.Duration - syncBase syncBase - peerLock sync.Mutex - peers map[p2p.Peer]struct{} - clock clockwork.Clock -} - -func NewMultiPeerReconciler(logger zap.Logger, syncBase syncBase, opts ...MultiPeerReconcilerOpt) *MultiPeerReconciler { - return &MultiPeerReconciler{ - // minPartSyncPeers: 2, - minPartSyncCount: 1000, - minFullFraction: 0.95, - splitSyncGracePeriod: time.Minute, - syncBase: syncBase, - clock: clockwork.NewRealClock(), +func withSyncRunner(runner syncRunner) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.runner = runner } } -func (mpr *MultiPeerReconciler) addPeer(p p2p.Peer) { - mpr.peerLock.Lock() - defer mpr.peerLock.Unlock() - mpr.peers[p] = struct{}{} +type runner struct { + mpr *MultiPeerReconciler } -func (mpr *MultiPeerReconciler) removePeer(p p2p.Peer) { - mpr.peerLock.Lock() - defer mpr.peerLock.Unlock() - delete(mpr.peers, p) +var _ syncRunner = &runner{} + +func (r *runner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { + s := newSplitSync( + r.mpr.logger, r.mpr.syncBase, 
r.mpr.peers, syncPeers, + r.mpr.splitSyncGracePeriod, r.mpr.clock) + return s.sync(ctx) } -func (mpr *MultiPeerReconciler) numPeers() int { - mpr.peerLock.Lock() - defer mpr.peerLock.Unlock() - return len(mpr.peers) +func (r *runner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { + return r.mpr.fullSync(ctx, syncPeers) } -func (mpr *MultiPeerReconciler) listPeers() []p2p.Peer { - mpr.peerLock.Lock() - defer mpr.peerLock.Unlock() - return maps.Keys(mpr.peers) +type MultiPeerReconciler struct { + logger *zap.Logger + syncBase syncBase + peers *peers.Peers + syncPeerCount int + minSplitSyncPeers int + minSplitSyncCount int + maxFullDiff int + minCompleteFraction float64 + splitSyncGracePeriod time.Duration + syncInterval time.Duration + noPeersRecheckInterval time.Duration + clock clockwork.Clock + runner syncRunner } -func (mpr *MultiPeerReconciler) havePeer(p p2p.Peer) bool { - mpr.peerLock.Lock() - defer mpr.peerLock.Unlock() - _, found := mpr.peers[p] - return found +func NewMultiPeerReconciler( + logger *zap.Logger, + syncBase syncBase, + peers *peers.Peers, + opts ...MultiPeerReconcilerOpt, +) *MultiPeerReconciler { + mpr := &MultiPeerReconciler{ + logger: logger, + syncBase: syncBase, + peers: peers, + syncPeerCount: 20, + minSplitSyncPeers: 2, + minSplitSyncCount: 1000, + maxFullDiff: 10000, + syncInterval: 5 * time.Minute, + minCompleteFraction: 0.5, + splitSyncGracePeriod: time.Minute, + noPeersRecheckInterval: 30 * time.Second, + clock: clockwork.NewRealClock(), + } + for _, opt := range opts { + opt(mpr) + } + if mpr.runner == nil { + mpr.runner = &runner{mpr: mpr} + } + return mpr } -func (mpr *MultiPeerReconciler) probePeers(ctx context.Context) (*probeResult, error) { - var pr probeResult - for _, p := range mpr.listPeers() { - count, err := mpr.syncBase.probe(ctx, p) +func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p.Peer) (syncability, error) { + var s syncability + for _, p := range syncPeers { + pr, err := 
mpr.syncBase.probe(ctx, p) if err != nil { log.Warning("error probing the peer", zap.Any("peer", p), zap.Error(err)) if errors.Is(err, context.Canceled) { - return nil, err + return s, err } continue } - if pr.probed == nil { - pr.probed = map[p2p.Peer]int{ - p: count, + s.syncable = append(s.syncable, p) + if pr.Count > mpr.minSplitSyncCount { + s.splitSyncable = append(s.splitSyncable, p) + } + if (1-pr.Sim)*float64(mpr.syncBase.count()) < float64(mpr.maxFullDiff) { + s.nearFullCount++ + } + } + return s, nil +} + +func (mpr *MultiPeerReconciler) needSplitSync(s syncability) bool { + if float64(s.nearFullCount) >= float64(mpr.syncBase.count())*mpr.minCompleteFraction { + // enough peers are close to this one according to minhash score, can do + // full sync + return false + } + + if len(s.splitSyncable) < mpr.minSplitSyncPeers { + // would be nice to do split sync, but not enough peers for that + return false + } + + return true +} + +func (mpr *MultiPeerReconciler) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { + var eg errgroup.Group + for _, p := range syncPeers { + syncer := mpr.syncBase.derive(p) + eg.Go(func() error { + err := syncer.sync(ctx, nil, nil) + switch { + case err == nil: + case errors.Is(err, context.Canceled): + return err + default: + mpr.logger.Error("error syncing peer", zap.Stringer("peer", p), zap.Error(err)) } - pr.minCount = count - pr.maxCount = count - } else { - pr.probed[p] = count - if count < pr.minCount { - pr.minCount = count + return nil + }) + } + return eg.Wait() +} + +func (mpr *MultiPeerReconciler) Run(ctx context.Context) error { + // The point of using split sync, which syncs different key ranges against + // different peers, vs full sync which syncs the full key range against different + // peers, is: + // 1. Avoid getting too many range splits and thus network transfer overhead + // 2. Avoid fetching same keys from multiple peers + + // States: + // A. Wait. Pause for sync interval + // Timeout => A + // B. 
No peers -> do nothing. + // Got any peers => C + // C. Low on peers. Wait for more to appear + // Lost all peers => B + // Got enough peers => D + // Timeout => D + // D. Probe the peers. Use successfully probed ones in states E/F + // Drop failed peers from the peer set while polling. + // All probes failed => B + // N of peers < minSplitSyncPeers => E + // All are low on count (minSplitSyncCount) => F + // Enough peers (minCompleteFraction) with diffSize <= maxFullDiff => E + // diffSize = (1-sim)*localItemCount + // Otherwise => F + // E. Full sync. Run full syncs against each peer + // All syncs completed (success / fail) => A + // F. Bounded sync. Subdivide the range by peers and start syncs. + // Use peers with > minSplitSyncCount + // Wait for all the syncs to complete/fail + // All syncs completed (success / fail) => A + ctx, cancel := context.WithCancel(ctx) + var eg errgroup.Group + eg.Go(func() error { + err := mpr.syncBase.run(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + cancel() + mpr.logger.Error("error processing synced items", zap.Error(err)) + return err + } + return nil + }) + defer func() { + cancel() + eg.Wait() + }() + + for { + select { + case <-ctx.Done(): + cancel() + // if the key handlers have caused an error, return that error + return eg.Wait() + case <-mpr.clock.After(mpr.syncInterval): + } + + var ( + s syncability + err error + ) + for { + syncPeers := mpr.peers.SelectBest(mpr.syncPeerCount) + if len(syncPeers) != 0 { + // probePeers doesn't return transient errors, sync must stop if it failed + s, err = mpr.probePeers(ctx, syncPeers) + if err != nil { + return err + } + if len(s.syncable) != 0 { + break + } } - if count > pr.maxCount { - pr.maxCount = count + + select { + case <-ctx.Done(): + return ctx.Err() + case <-mpr.clock.After(mpr.noPeersRecheckInterval): } } + + if mpr.needSplitSync(s) { + err = mpr.runner.splitSync(ctx, s.splitSyncable) + } else { + err = mpr.runner.fullSync(ctx, s.splitSyncable) + } + + if 
err != nil { + return err + } } - return &pr, nil } -// func (mpr *MultiPeerReconciler) splitSync(ctx context.Context, peers []p2p.Peer) error { -// // Use priority queue. Higher priority = more time since started syncing -// // Highest priority = not started syncing yet -// // Mark syncRange as synced when it's done, next time it's popped from the queue, -// // it will be dropped -// // When picking up an entry which is already being synced, start with -// // SyncTree of the entry -// // TODO: when all of the ranges are synced at least once, just return. -// // The remaining syncs will be canceled -// // TODO: when no available peers remain, return failure -// if len(peers) == 0 { -// panic("BUG: no peers passed to splitSync") -// } -// syncCtx, cancel := context.WithCancel(ctx) -// defer cancel() -// delim := getDelimiters(len(peers)) -// sq := make(syncQueue, len(peers)) -// var y types.Hash32 -// for n := range sq { -// x := y -// if n == len(peers)-1 { -// y = types.Hash32{} -// } else { -// y = delim[n] -// } -// sq[n] = &syncRange{ -// x: x, -// y: y, -// } -// } -// heap.Init(&sq) -// peers = slices.Clone(peers) -// resCh := make(chan syncResult) -// syncMap := make(map[p2p.Peer]*syncRange) -// numRunning := 0 -// numRemaining := len(peers) -// numPeers := len(peers) -// needGracePeriod := true -// for numRemaining > 0 { -// p := peers[0] -// peers = peers[1:] -// var sr *syncRange -// for len(sq) != 0 { -// sr = heap.Pop(&sq).(*syncRange) -// if !sr.done { -// break -// } -// sr = nil -// } -// if sr == nil { -// panic("BUG: bad syncRange accounting in splitSync") -// } -// syncMap[p] = sr -// var s syncer -// if len(sr.syncers) != 0 { -// // derive from an existing syncer to get sync against -// // more up-to-date data -// s = sr.syncers[len(sr.syncers)-1].derive(p) -// } else { -// s = mpr.syncBase.derive(p) -// } -// sr.syncers = append(sr.syncers, s) -// numRunning++ -// // push this syncRange to the back of the queue as a fresh sync -// // is just 
starting -// sq.update(sr, mpr.clock.Now()) -// go func() { -// err := s.sync(syncCtx, &sr.x, &sr.y) -// select { -// case <-syncCtx.Done(): -// case resCh <- syncResult{s: s, err: err}: -// } -// }() - -// peers := slices.DeleteFunc(peers, func(p p2p.Peer) bool { -// return !mpr.havePeer(p) -// }) - -// // Grace period: after at least one syncer finishes, wait a bit -// // before assigning it another range to avoid unneeded traffic. -// // The grace period ends if any of the syncers fail -// var gpTimer <-chan time.Time -// if needGracePeriod { -// gpTimer = mpr.clock.After(mpr.splitSyncGracePeriod) -// } -// for needGracePeriod && len(peers) == 0 { -// if numRunning == 0 { -// return errors.New("all peers dropped before full sync has completed") -// } - -// var r syncResult -// select { -// case <-syncCtx.Done(): -// return syncCtx.Err() -// case r = <-resCh: -// case <-gpTimer: -// needGracePeriod = false -// } - -// sr, found := syncMap[s.peer()] -// if !found { -// panic("BUG: error in split sync syncMap handling") -// } -// numRunning-- -// delete(syncMap, s.peer()) -// n := slices.Index(sr.syncers, s) -// if n < 0 { -// panic("BUG: bad syncers in syncRange") -// } -// sr.syncers = slices.Delete(sr.syncers, n, n+1) -// if r.err != nil { -// numPeers-- -// mpr.RemovePeer(s.peer()) -// if numPeers == 0 && numRemaining != 0 { -// return errors.New("all peers dropped before full sync has completed") -// } -// if len(sr.syncers) == 0 { -// // prioritize the syncRange for resync after failed -// // sync with no active syncs remaining -// sq.update(sr, time.Time{}) -// } -// needGracePeriod = false -// } else { -// sr.done = true -// peers = append(peers, s.peer()) -// numRemaining-- -// } -// } -// } - -// return nil -// } - -func (mpr *MultiPeerReconciler) run(ctx context.Context) error { - // States: - // A. No peers -> do nothing. - // Got any peers => B - // B. Low on peers. 
Wait for more to appear - // Lost all peers => A - // Got enough peers => C - // Timeout => C - // C. Probe the peers. Use successfully probed ones in states D/E - // All probes failed => A - // All are low on count (minPartSyncCount) => E - // Some have substantially higher count (minFullFraction) => D - // Otherwise => E - // D. Bounded sync. Subdivide the range by peers and start syncs. - // Use peers with > minPartSyncCount - // Wait for all the syncs to complete/fail - // All syncs succeeded => A - // Any syncs failed => A - // E. Full sync. Run full syncs against each peer - // All syncs completed (success / fail) => F - // F. Wait. Pause for sync interval - // Timeout => A - panic("TBD") +type HashSyncBase struct { + r requester + is ItemStore } diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 918c6fffb2..28f624a2e0 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -1,6 +1,7 @@ package hashsync import ( + "context" "errors" "fmt" "reflect" @@ -167,7 +168,7 @@ type RangeInfo struct { type ItemStore interface { // Add adds a key-value pair to the store - Add(k Ordered, v any) + Add(ctx context.Context, k Ordered, v any) error // GetRangeInfo returns RangeInfo for the item range in the tree. // If count >= 0, at most count items are returned, and RangeInfo // is returned for the corresponding subrange of the requested range. 
@@ -182,6 +183,10 @@ type ItemStore interface { Max() Iterator // New returns an empty payload value New() any + // Copy makes a shallow copy of the ItemStore + Copy() ItemStore + // Has returns true if the specified key is present in ItemStore + Has(k Ordered) bool } type ProbeResult struct { @@ -223,7 +228,7 @@ func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { // return fmt.Sprintf("%s", it.Key()) // } -func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding, start, end Iterator, x, y Ordered) (Iterator, error) { +func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, y Ordered) (Iterator, error) { if preceding != nil && preceding.Key().Compare(x) > 0 { preceding = nil } @@ -358,12 +363,12 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg panic("BUG: can't split range with count > 1") } middle := part.End.Key() - next, err := rsr.processSubrange(c, info.Start, part.Start, part.End, x, middle) + next, err := rsr.processSubrange(c, info.Start, x, middle) if err != nil { return nil, false, err } // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) - _, err = rsr.processSubrange(c, next, part.End, info.End, middle, y) + _, err = rsr.processSubrange(c, next, middle, y) if err != nil { return nil, false, err } @@ -540,7 +545,7 @@ func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (p } } -func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { +func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done bool, err error) { var msgs []SyncMessage // All of the messages need to be received before processing // them, as processing the messages involves sending more @@ -558,7 +563,9 @@ func (rsr *RangeSetReconciler) Process(c Conduit) (done bool, err error) { if msg.Type() == MessageTypeItemBatch { vals := msg.Values() for n, k := range msg.Keys() { - rsr.is.Add(k, vals[n]) + if err := rsr.is.Add(ctx, k, 
vals[n]); err != nil { + return false, fmt.Errorf("error adding an item to the store: %w", err) + } } continue } diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index d9c7cac3b9..bf2ba15163 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -1,6 +1,8 @@ package hashsync import ( + "context" + "fmt" "math/rand" "slices" "testing" @@ -214,7 +216,7 @@ type dumbStore struct { var _ ItemStore = &dumbStore{} -func (ds *dumbStore) Add(k Ordered, v any) { +func (ds *dumbStore) Add(ctx context.Context, k Ordered, v any) error { if ds.m == nil { ds.m = make(map[sampleID]any) } @@ -222,7 +224,7 @@ func (ds *dumbStore) Add(k Ordered, v any) { if len(ds.keys) == 0 { ds.keys = []sampleID{id} ds.m[id] = v - return + return nil } p := slices.IndexFunc(ds.keys, func(other sampleID) bool { return other >= id @@ -237,6 +239,8 @@ func (ds *dumbStore) Add(k Ordered, v any) { ds.keys = slices.Insert(ds.keys, p, id) ds.m[id] = v } + + return nil } func (ds *dumbStore) iter(n int) Iterator { @@ -320,6 +324,19 @@ func (it *dumbStore) New() any { panic("not implemented") } +func (ds *dumbStore) Copy() ItemStore { + panic("not implemented") +} + +func (ds *dumbStore) Has(k Ordered) bool { + for _, cur := range ds.keys { + if k.Compare(cur) == 0 { + return true + } + } + return false +} + type verifiedStoreIterator struct { t *testing.T knownGood Iterator @@ -375,7 +392,7 @@ func disableReAdd(s ItemStore) { } } -func (vs *verifiedStore) Add(k Ordered, v any) { +func (vs *verifiedStore) Add(ctx context.Context, k Ordered, v any) error { if vs.disableReAdd { _, found := vs.added[k.(sampleID)] require.False(vs.t, found, "hash sent twice: %v", k) @@ -384,8 +401,13 @@ func (vs *verifiedStore) Add(k Ordered, v any) { } vs.added[k.(sampleID)] = struct{}{} } - vs.knownGood.Add(k, v) - vs.store.Add(k, v) + if err := vs.knownGood.Add(ctx, k, v); err != nil { + return fmt.Errorf("add to knownGood: %w", err) + } + if err := vs.store.Add(ctx, k, v); err != nil 
{ + return fmt.Errorf("add to store: %w", err) + } + return nil } func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { @@ -476,6 +498,22 @@ func (vs *verifiedStore) New() any { return v2 } +func (vs *verifiedStore) Copy() ItemStore { + return &verifiedStore{ + t: vs.t, + knownGood: vs.knownGood.Copy(), + store: vs.store.Copy(), + disableReAdd: vs.disableReAdd, + } +} + +func (vs *verifiedStore) Has(k Ordered) bool { + h1 := vs.knownGood.Has(k) + h2 := vs.store.Has(k) + require.Equal(vs.t, h1, h2) + return h2 +} + type storeFactory func(t *testing.T) ItemStore func makeDumbStore(t *testing.T) ItemStore { @@ -500,7 +538,7 @@ func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { func makeStore(t *testing.T, f storeFactory, items string) ItemStore { s := f(t) for _, c := range items { - s.Add(sampleID(c), "") + require.NoError(t, s.Add(context.Background(), sampleID(c), "")) } return s } @@ -580,7 +618,7 @@ func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) nMsg += len(fc.msgs) nItems += fc.numItems() var err error - bDone, err = syncB.Process(fc) + bDone, err = syncB.Process(context.Background(), fc) require.NoError(fc.t, err) // a party should never send anything in response to the "done" message require.False(fc.t, aDone && !bDone, "A is done but B after that is not") @@ -593,7 +631,7 @@ func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) fc.gotoResponse() nMsg += len(fc.msgs) nItems += fc.numItems() - aDone, err = syncA.Process(fc) + aDone, err = syncA.Process(context.Background(), fc) require.NoError(fc.t, err) // dumpRangeMessages(fc.t, fc.msgs, "A %q --> B %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B:") @@ -623,7 +661,7 @@ func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y Ordered) P func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler, info RangeInfo) ProbeResult { 
require.NotEmpty(fc.t, fc.resp, "empty initial round") fc.gotoResponse() - done, err := to.Process(fc) + done, err := to.Process(context.Background(), fc) require.True(fc.t, done) require.NoError(fc.t, err) fc.gotoResponse() diff --git a/hashsync/setsyncbase.go b/hashsync/setsyncbase.go new file mode 100644 index 0000000000..a088dd7b83 --- /dev/null +++ b/hashsync/setsyncbase.go @@ -0,0 +1,121 @@ +package hashsync + +import ( + "context" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/spacemeshos/go-spacemesh/common/types" + "golang.org/x/sync/errgroup" +) + +type syncKeyHandler func(ctx context.Context, k Ordered) error + +type setSyncBase struct { + r requester + is ItemStore + handler syncKeyHandler + opts []Option + keyCh chan Ordered +} + +var _ syncBase = &setSyncBase{} + +func newSetSyncBase(r requester, is ItemStore, handler syncKeyHandler, opts ...Option) *setSyncBase { + return &setSyncBase{ + r: r, + is: is, + handler: handler, + opts: opts, + keyCh: make(chan Ordered), + } +} + +// count implements syncBase. +func (ssb *setSyncBase) count() int { + it := ssb.is.Min() + if it == nil { + return 0 + } + x := it.Key() + return ssb.is.GetRangeInfo(nil, x, x, -1).Count +} + +// derive implements syncBase. +func (ssb *setSyncBase) derive(p peer.ID) syncer { + return &setSyncer{ + ItemStore: ssb.is.Copy(), + r: ssb.r, + opts: ssb.opts, + p: p, + keyCh: ssb.keyCh, + } +} + +// probe implements syncBase. +func (ssb *setSyncBase) probe(ctx context.Context, p peer.ID) (ProbeResult, error) { + return Probe(ctx, ssb.r, p, ssb.is, nil, nil, ssb.opts...) +} + +// run implements syncBase. 
+func (ssb *setSyncBase) run(ctx context.Context) error { + eg, ctx := errgroup.WithContext(ctx) + doneCh := make(chan Ordered) + beingProcessed := make(map[Ordered]struct{}) + for { + select { + case <-ctx.Done(): + return eg.Wait() + case k := <-ssb.keyCh: + if ssb.is.Has(k) { + continue + } + if _, found := beingProcessed[k]; found { + continue + } + eg.Go(func() error { + defer func() { + select { + case <-ctx.Done(): + case doneCh <- k: + } + }() + return ssb.handler(ctx, k) + }) + case k := <-doneCh: + delete(beingProcessed, k) + } + } +} + +type setSyncer struct { + ItemStore + r requester + opts []Option + p peer.ID + keyCh chan<- Ordered +} + +var ( + _ syncer = &setSyncer{} + _ ItemStore = &setSyncer{} +) + +// peer implements syncer. +func (ss *setSyncer) peer() peer.ID { + return ss.p +} + +// sync implements syncer. +func (ss *setSyncer) sync(ctx context.Context, x, y *types.Hash32) error { + return SyncStore(ctx, ss.r, ss.p, ss, x, y, ss.opts...) +} + +// Add implements ItemStore. 
+func (ss *setSyncer) Add(ctx context.Context, k Ordered, v any) error { + select { + case <-ctx.Done(): + return ctx.Err() + case ss.keyCh <- k: + } + return ss.ItemStore.Add(ctx, k, v) +} diff --git a/hashsync/split_sync.go b/hashsync/split_sync.go index 9ce428cf1b..5d272a4296 100644 --- a/hashsync/split_sync.go +++ b/hashsync/split_sync.go @@ -12,6 +12,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/p2p" ) @@ -23,14 +24,15 @@ type syncResult struct { type splitSync struct { logger *zap.Logger syncBase syncBase - peerSet peerSet - peers []p2p.Peer + peers *peers.Peers + syncPeers []p2p.Peer gracePeriod time.Duration clock clockwork.Clock sq syncQueue resCh chan syncResult slowRangeCh chan *syncRange syncMap map[p2p.Peer]*syncRange + failedPeers map[p2p.Peer]struct{} numRunning int numRemaining int numPeers int @@ -41,35 +43,36 @@ type splitSync struct { func newSplitSync( logger *zap.Logger, syncBase syncBase, - peerSet peerSet, - peers []p2p.Peer, + peers *peers.Peers, + syncPeers []p2p.Peer, gracePeriod time.Duration, clock clockwork.Clock, ) *splitSync { - if len(peers) == 0 { + if len(syncPeers) == 0 { panic("BUG: no peers passed to splitSync") } return &splitSync{ logger: logger, syncBase: syncBase, - peerSet: peerSet, peers: peers, + syncPeers: syncPeers, gracePeriod: gracePeriod, clock: clock, - sq: newSyncQueue(len(peers)), + sq: newSyncQueue(len(syncPeers)), resCh: make(chan syncResult), syncMap: make(map[p2p.Peer]*syncRange), - numRemaining: len(peers), - numPeers: len(peers), + failedPeers: make(map[p2p.Peer]struct{}), + numRemaining: len(syncPeers), + numPeers: len(syncPeers), } } func (s *splitSync) nextPeer() p2p.Peer { - if len(s.peers) == 0 { + if len(s.syncPeers) == 0 { panic("BUG: no peers") } - p := s.peers[0] - s.peers = s.peers[1:] + p := s.syncPeers[0] + s.syncPeers = s.syncPeers[1:] return p } @@ -113,13 
+116,13 @@ func (s *splitSync) handleSyncResult(r syncResult) error { sr.numSyncers-- if r.err != nil { s.numPeers-- - s.peerSet.removePeer(r.s.peer()) + s.failedPeers[r.s.peer()] = struct{}{} s.logger.Debug("remove failed peer", zap.Stringer("peer", r.s.peer()), zap.Int("numPeers", s.numPeers), zap.Int("numRemaining", s.numRemaining), zap.Int("numRunning", s.numRunning), - zap.Int("availPeers", len(s.peers))) + zap.Int("availPeers", len(s.syncPeers))) if s.numPeers == 0 && s.numRemaining != 0 { return errors.New("all peers dropped before full sync has completed") } @@ -131,22 +134,26 @@ func (s *splitSync) handleSyncResult(r syncResult) error { } } else { sr.done = true - s.peers = append(s.peers, r.s.peer()) + s.syncPeers = append(s.syncPeers, r.s.peer()) s.numRemaining-- s.logger.Debug("peer synced successfully", zap.Stringer("peer", r.s.peer()), zap.Int("numPeers", s.numPeers), zap.Int("numRemaining", s.numRemaining), zap.Int("numRunning", s.numRunning), - zap.Int("availPeers", len(s.peers))) + zap.Int("availPeers", len(s.syncPeers))) } return nil } func (s *splitSync) clearDeadPeers() { - s.peers = slices.DeleteFunc(s.peers, func(p p2p.Peer) bool { - return !s.peerSet.havePeer(p) + s.syncPeers = slices.DeleteFunc(s.syncPeers, func(p p2p.Peer) bool { + if !s.peers.Contains(p) { + return true + } + _, failed := s.failedPeers[p] + return failed }) } @@ -171,9 +178,9 @@ func (s *splitSync) sync(ctx context.Context) error { break } s.clearDeadPeers() - for s.numRemaining > 0 && (s.sq.empty() || len(s.peers) == 0) { + for s.numRemaining > 0 && (s.sq.empty() || len(s.syncPeers) == 0) { s.logger.Debug("QQQQQ: loop") - if s.numRunning == 0 && len(s.peers) == 0 { + if s.numRunning == 0 && len(s.syncPeers) == 0 { return errors.New("all peers dropped before full sync has completed") } select { diff --git a/hashsync/split_sync_test.go b/hashsync/split_sync_test.go index 4e7cd9d181..843e5c1dbc 100644 --- a/hashsync/split_sync_test.go +++ b/hashsync/split_sync_test.go @@ 
-8,12 +8,14 @@ import ( "time" "github.com/jonboulle/clockwork" - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/p2p" "github.com/stretchr/testify/require" gomock "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/fetch/peers" + "github.com/spacemeshos/go-spacemesh/p2p" ) func hexDelimiters(n int) (r []string) { @@ -70,14 +72,14 @@ func TestGetDelimiters(t *testing.T) { type splitSyncTester struct { testing.TB - peers []p2p.Peer + syncPeers []p2p.Peer clock clockwork.Clock mtx sync.Mutex fail map[hexRange]bool expPeerRanges map[hexRange]int peerRanges map[hexRange][]p2p.Peer syncBase *MocksyncBase - peerSet *MockpeerSet + peers *peers.Peers splitSync *splitSync } @@ -103,9 +105,9 @@ var tstRanges = []hexRange{ func newTestSplitSync(t testing.TB) *splitSyncTester { ctrl := gomock.NewController(t) tst := &splitSyncTester{ - peers: make([]p2p.Peer, 4), - clock: clockwork.NewFakeClock(), - fail: make(map[hexRange]bool), + syncPeers: make([]p2p.Peer, 4), + clock: clockwork.NewFakeClock(), + fail: make(map[hexRange]bool), expPeerRanges: map[hexRange]int{ tstRanges[0]: 0, tstRanges[1]: 0, @@ -114,12 +116,12 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { }, peerRanges: make(map[hexRange][]p2p.Peer), syncBase: NewMocksyncBase(ctrl), - peerSet: NewMockpeerSet(ctrl), + peers: peers.New(), } - for n := range tst.peers { - tst.peers[n] = p2p.Peer(types.RandomBytes(20)) + for n := range tst.syncPeers { + tst.syncPeers[n] = p2p.Peer(types.RandomBytes(20)) } - for index, p := range tst.peers { + for index, p := range tst.syncPeers { index := index p := p tst.syncBase.EXPECT(). @@ -153,18 +155,14 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { }). AnyTimes() } - tst.peerSet.EXPECT(). - havePeer(gomock.Any()). 
- DoAndReturn(func(p p2p.Peer) bool { - require.Contains(t, tst.peers, p) - return true - }). - AnyTimes() + for _, p := range tst.syncPeers { + tst.peers.Add(p) + } tst.splitSync = newSplitSync( zaptest.NewLogger(t), tst.syncBase, - tst.peerSet, tst.peers, + tst.syncPeers, time.Minute, tst.clock, ) @@ -187,11 +185,6 @@ func TestSplitSyncRetry(t *testing.T) { tst := newTestSplitSync(t) tst.fail[tstRanges[1]] = true tst.fail[tstRanges[2]] = true - removedPeers := make(map[p2p.Peer]bool) - tst.peerSet.EXPECT().removePeer(gomock.Any()).DoAndReturn(func(peer p2p.Peer) { - require.NotContains(t, removedPeers, peer) - removedPeers[peer] = true - }).Times(2) var eg errgroup.Group eg.Go(func() error { return tst.splitSync.sync(context.Background()) @@ -201,15 +194,6 @@ func TestSplitSyncRetry(t *testing.T) { require.False(t, tst.fail[pr], "fail cleared for x %s y %s", pr[0], pr[1]) require.Equal(t, 1, count, "peer range not synced: x %s y %s", pr[0], pr[1]) } - for _, r := range []hexRange{tstRanges[1], tstRanges[2]} { - haveFailedPeers := false - for _, peer := range tst.peerRanges[r] { - if removedPeers[peer] { - haveFailedPeers = true - } - } - require.True(t, haveFailedPeers) - } } // TODO: test cancel diff --git a/hashsync/sync_tree.go b/hashsync/sync_tree.go index bd2990ad00..ff995f743b 100644 --- a/hashsync/sync_tree.go +++ b/hashsync/sync_tree.go @@ -33,14 +33,12 @@ func (fpred FingerprintPredicate) Match(y any) bool { } type SyncTree interface { - // Make a copy of the tree. The copy shares the structure with - // this tree but all its nodes are copy-on-write, so any - // changes in the copied tree do not affect this one and are - // safe to perform in another goroutine. The copy operation is - // O(n) where n is the number of nodes added to this tree - // since its creation via either NewSyncTree function or this - // Copy method, or the last call of this Copy method for this - // tree, whichever occurs last. The call to Copy is thread-safe. 
+ // Make a copy of the tree. The copy shares the structure with this tree but all + // its nodes are copy-on-write, so any changes in the copied tree do not affect + // this one and are safe to perform in another goroutine. The copy operation is + // O(n) where n is the number of nodes added to this tree since its creation via + // either NewSyncTree function or this Copy method, or the last call of this Copy + // method for this tree, whichever occurs last. The call to Copy is thread-safe. Copy() SyncTree Fingerprint() any Add(k Ordered) @@ -58,7 +56,7 @@ func SyncTreeFromSortedSlice[T Ordered](m Monoid, items []T) SyncTree { s[n] = item } st := NewSyncTree(m).(*syncTree) - st.root = st.buildFromSortedSlice(nil, s) + st.root = st.buildFromSortedSlice(s) return st } @@ -421,7 +419,7 @@ func (st *syncTree) Fingerprint() any { return st.root.fingerprint } -func (st *syncTree) newNode(parent *syncTreeNode, k Ordered, v any) *syncTreeNode { +func (st *syncTree) newNode(k Ordered, v any) *syncTreeNode { return &syncTreeNode{ key: k, value: v, @@ -430,17 +428,17 @@ func (st *syncTree) newNode(parent *syncTreeNode, k Ordered, v any) *syncTreeNod } } -func (st *syncTree) buildFromSortedSlice(parent *syncTreeNode, s []Ordered) *syncTreeNode { +func (st *syncTree) buildFromSortedSlice(s []Ordered) *syncTreeNode { switch len(s) { case 0: return nil case 1: - return st.newNode(nil, s[0], nil) + return st.newNode(s[0], nil) } middle := len(s) / 2 - node := st.newNode(parent, s[middle], nil) - node.left = st.buildFromSortedSlice(node, s[:middle]) - node.right = st.buildFromSortedSlice(node, s[middle+1:]) + node := st.newNode(s[middle], nil) + node.left = st.buildFromSortedSlice(s[:middle]) + node.right = st.buildFromSortedSlice(s[middle+1:]) if node.left != nil { node.fingerprint = st.m.Op(node.left.fingerprint, node.fingerprint) } @@ -526,7 +524,7 @@ func (st *syncTree) insert(sn *syncTreeNode, k Ordered, v any, rb, set bool) *sy // simplified insert implementation idea from 
// https://zarif98sjs.github.io/blog/blog/redblacktree/ if sn == nil { - sn = st.newNode(nil, k, v) + sn = st.newNode(k, v) // the new node is not really "cloned", but at this point it's // only present in this tree so we can safely modify it // without allocating new nodes diff --git a/hashsync/sync_tree_store.go b/hashsync/sync_tree_store.go index 048b1297a5..852e446841 100644 --- a/hashsync/sync_tree_store.go +++ b/hashsync/sync_tree_store.go @@ -1,5 +1,7 @@ package hashsync +import "context" + type ValueHandler interface { Load(k Ordered, treeValue any) (v any) Store(k Ordered, v any) (treeValue any) @@ -68,9 +70,10 @@ func NewSyncTreeStore(m Monoid, vh ValueHandler, newValue NewValueFunc) ItemStor } // Add implements ItemStore. -func (sts *SyncTreeStore) Add(k Ordered, v any) { +func (sts *SyncTreeStore) Add(ctx context.Context, k Ordered, v any) error { treeValue := sts.vh.Store(k, v) sts.st.Set(k, treeValue) + return nil } func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { @@ -137,3 +140,19 @@ func (sts *SyncTreeStore) Max() Iterator { func (sts *SyncTreeStore) New() any { return sts.newValue() } + +// Copy implements ItemStore. +func (sts *SyncTreeStore) Copy() ItemStore { + return &SyncTreeStore{ + st: sts.st.Copy(), + vh: sts.vh, + newValue: sts.newValue, + identity: sts.identity, + } +} + +// Has implements ItemStore. 
+func (sts *SyncTreeStore) Has(k Ordered) bool { + _, found := sts.st.Lookup(k) + return found +} diff --git a/hashsync/sync_tree_test.go b/hashsync/sync_tree_test.go index 144c66d72c..aaa3b526cf 100644 --- a/hashsync/sync_tree_test.go +++ b/hashsync/sync_tree_test.go @@ -36,7 +36,7 @@ func makeStringConcatTree(chars string) SyncTree { for n, c := range chars { ids[n] = sampleID(c) } - return SyncTreeFromSlice[sampleID](sampleCountMonoid(), ids) + return SyncTreeFromSlice(sampleCountMonoid(), ids) } // dumbAdd inserts the node into the tree without trying to maintain the @@ -67,7 +67,7 @@ func makeRBTree(chars string) SyncTree { return st } -func gtePos(all string, item string) int { +func gtePos(all, item string) int { n := slices.IndexFunc([]byte(all), func(v byte) bool { return v >= item[0] }) @@ -411,7 +411,7 @@ func testRandomOrderAndRanges(t *testing.T, mktree makeTestTreeFunc) { rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) - tree := makeDumbTree(string(shuffled)) + tree := mktree(string(shuffled)) x := sampleID(shuffled[rand.Intn(len(shuffled))]) y := sampleID(shuffled[rand.Intn(len(shuffled))]) stopCount := rand.Intn(len(shuffled)+2) - 1 diff --git a/hashsync/xorsync_test.go b/hashsync/xorsync_test.go index 9c9922925e..2d312b9c50 100644 --- a/hashsync/xorsync_test.go +++ b/hashsync/xorsync_test.go @@ -1,6 +1,7 @@ package hashsync import ( + "context" "math/rand" "slices" "testing" @@ -59,15 +60,18 @@ type catchTransferTwice struct { added map[types.Hash32]bool } -func (s *catchTransferTwice) Add(k Ordered, v any) { +func (s *catchTransferTwice) Add(ctx context.Context, k Ordered, v any) error { h := k.(types.Hash32) _, found := s.added[h] assert.False(s.t, found, "hash sent twice") - s.ItemStore.Add(k, v) + if err := s.ItemStore.Add(ctx, k, v); err != nil { + return err + } if s.added == nil { s.added = make(map[types.Hash32]bool) } s.added[h] = true + return nil } type xorSyncTestConfig struct { @@ 
-83,8 +87,10 @@ type fakeValue struct { v string } -var _ scale.Decodable = &fakeValue{} -var _ scale.Encodable = &fakeValue{} +var ( + _ scale.Decodable = &fakeValue{} + _ scale.Encodable = &fakeValue{} +) func mkFakeValue(h types.Hash32) *fakeValue { return &fakeValue{v: h.String()} @@ -114,7 +120,7 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB sliceA := src[:cfg.numTestHashes-numSpecificB] storeA := NewSyncTreeStore(Hash32To12Xor{}, nil, func() any { return new(fakeValue) }) for _, h := range sliceA { - storeA.Add(h, mkFakeValue(h)) + require.NoError(t, storeA.Add(context.Background(), h, mkFakeValue(h))) } storeA = &catchTransferTwice{t: t, ItemStore: storeA} @@ -122,7 +128,7 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) storeB := NewSyncTreeStore(Hash32To12Xor{}, nil, func() any { return new(fakeValue) }) for _, h := range sliceB { - storeB.Add(h, mkFakeValue(h)) + require.NoError(t, storeB.Add(context.Background(), h, mkFakeValue(h))) } storeB = &catchTransferTwice{t: t, ItemStore: storeB} From 4a4bc35c005818de3ee81518bfa668f0c082b36f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 15 May 2024 18:04:22 +0400 Subject: [PATCH 19/76] hashsync: don't send values Existing fetcher will be used for blobs --- hashsync/handler.go | 84 +++--------------------------------- hashsync/handler_test.go | 51 ++++++++-------------- hashsync/rangesync.go | 26 ++++------- hashsync/rangesync_test.go | 60 +++++--------------------- hashsync/setsyncbase.go | 4 +- hashsync/sync_tree_store.go | 42 ++---------------- hashsync/wire_types.go | 23 +++++----- hashsync/wire_types_scale.go | 15 ------- hashsync/xorsync_test.go | 61 ++++++-------------------- 9 files changed, 75 insertions(+), 291 deletions(-) diff --git a/hashsync/handler.go b/hashsync/handler.go index 63090f32ff..97086502ca 100644 --- a/hashsync/handler.go +++ 
b/hashsync/handler.go @@ -19,65 +19,6 @@ type sendable interface { Type() MessageType } -type decodedItemBatchMessage struct { - ContentKeys []types.Hash32 - ContentValues []any -} - -var _ SyncMessage = &decodedItemBatchMessage{} - -func (m *decodedItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } -func (m *decodedItemBatchMessage) X() Ordered { return nil } -func (m *decodedItemBatchMessage) Y() Ordered { return nil } -func (m *decodedItemBatchMessage) Fingerprint() any { return nil } -func (m *decodedItemBatchMessage) Count() int { return 0 } -func (m *decodedItemBatchMessage) Keys() []Ordered { - r := make([]Ordered, len(m.ContentKeys)) - for n, k := range m.ContentKeys { - r[n] = k - } - return r -} - -func (m *decodedItemBatchMessage) Values() []any { - r := make([]any, len(m.ContentValues)) - for n, v := range m.ContentValues { - r[n] = v - } - return r -} - -func (m *decodedItemBatchMessage) encode() (*ItemBatchMessage, error) { - var b bytes.Buffer - for _, v := range m.ContentValues { - _, err := codec.EncodeTo(&b, v.(codec.Encodable)) - if err != nil { - return nil, err - } - } - return &ItemBatchMessage{ - ContentKeys: m.ContentKeys, - ContentValues: b.Bytes(), - }, nil -} - -func decodeItemBatchMessage(m *ItemBatchMessage, newValue NewValueFunc) (*decodedItemBatchMessage, error) { - d := &decodedItemBatchMessage{ContentKeys: m.ContentKeys} - b := bytes.NewBuffer(m.ContentValues) - for b.Len() != 0 { - v := newValue().(codec.Decodable) - if _, err := codec.DecodeFrom(b, v); err != nil { - return nil, err - } - d.ContentValues = append(d.ContentValues, v) - } - if len(d.ContentValues) != len(d.ContentKeys) { - return nil, fmt.Errorf("mismatched key / value counts: %d / %d", - len(d.ContentKeys), len(d.ContentValues)) - } - return d, nil -} - // QQQQQ: rmme var ( numRead int @@ -112,7 +53,6 @@ type conduitState int type wireConduit struct { stream io.ReadWriter initReqBuf *bytes.Buffer - newValue NewValueFunc // rmmePrint bool } @@ 
-139,11 +79,7 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { if _, err := codec.DecodeFrom(c.stream, &m); err != nil { return nil, err } - dm, err := decodeItemBatchMessage(&m, c.newValue) - if err != nil { - return nil, err - } - return dm, nil + return &m, nil case MessageTypeEmptySet: return &EmptySetMessage{}, nil case MessageTypeEmptyRange: @@ -226,22 +162,18 @@ func (c *wireConduit) SendRangeContents(x, y Ordered, count int) error { func (c *wireConduit) SendItems(count, itemChunkSize int, it Iterator) error { for i := 0; i < count; i += itemChunkSize { - var msg decodedItemBatchMessage + // TBD: do not use chunks, just stream the contentkeys + var msg ItemBatchMessage n := min(itemChunkSize, count-i) for n > 0 { if it.Key() == nil { panic("fakeConduit.SendItems: went got to the end of the tree") } msg.ContentKeys = append(msg.ContentKeys, it.Key().(types.Hash32)) - msg.ContentValues = append(msg.ContentValues, it.Value()) it.Next() n-- } - encoded, err := msg.encode() - if err != nil { - return err - } - if err := c.send(encoded); err != nil { + if err := c.send(&msg); err != nil { return err } } @@ -327,7 +259,7 @@ func (c *wireConduit) ShortenKey(k Ordered) Ordered { func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { return func(ctx context.Context, req []byte, stream io.ReadWriter) error { - c := wireConduit{newValue: is.New} + var c wireConduit rsr := NewRangeSetReconciler(is, opts...) s := struct { io.Reader @@ -342,7 +274,7 @@ func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { } func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, opts ...Option) error { - c := wireConduit{newValue: is.New} + var c wireConduit rsr := NewRangeSetReconciler(is, opts...) 
// c.rmmePrint = true var ( @@ -379,9 +311,7 @@ func Probe( info RangeInfo pr ProbeResult ) - c := wireConduit{ - newValue: func() any { return nil }, // not used - } + var c wireConduit rsr := NewRangeSetReconciler(is, opts...) if x == nil { initReq, err = c.withInitialRequest(func(c Conduit) error { diff --git a/hashsync/handler_test.go b/hashsync/handler_test.go index bdc84d5f02..0488b61559 100644 --- a/hashsync/handler_test.go +++ b/hashsync/handler_test.go @@ -125,14 +125,6 @@ func (it *sliceIterator) Key() Ordered { return nil } -func (it *sliceIterator) Value() any { - k := it.Key() - if k == nil { - return nil - } - return mkFakeValue(k.(types.Hash32)) -} - func (it *sliceIterator) Next() { if len(it.s) != 0 { it.s = it.s[1:] @@ -204,8 +196,8 @@ func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { return nil } -func makeTestStreamHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, rounds []fakeRound) server.StreamHandler { - cbk := makeTestRequestCallback(t, c, newValue, rounds) +func makeTestStreamHandler(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamHandler { + cbk := makeTestRequestCallback(t, c, rounds) return func(ctx context.Context, initialRequest []byte, stream io.ReadWriter) error { t.Logf("init request bytes: %d", len(initialRequest)) s := struct { @@ -220,10 +212,10 @@ func makeTestStreamHandler(t *testing.T, c *wireConduit, newValue NewValueFunc, } } -func makeTestRequestCallback(t *testing.T, c *wireConduit, newValue NewValueFunc, rounds []fakeRound) server.StreamRequestCallback { +func makeTestRequestCallback(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamRequestCallback { return func(ctx context.Context, stream io.ReadWriter) error { if c == nil { - c = &wireConduit{stream: stream, newValue: newValue} + c = &wireConduit{stream: stream} } else { c.stream = stream } @@ -242,7 +234,7 @@ func TestWireConduit(t *testing.T) { hs[n] = types.RandomHash() } fp := 
types.Hash12(hs[2][:12]) - srvHandler := makeTestStreamHandler(t, nil, func() any { return new(fakeValue) }, []fakeRound{ + srvHandler := makeTestStreamHandler(t, nil, []fakeRound{ { name: "server got 1st request", expectMsgs: []SyncMessage{ @@ -276,13 +268,11 @@ func TestWireConduit(t *testing.T) { { name: "server got 2nd request", expectMsgs: []SyncMessage{ - &decodedItemBatchMessage{ - ContentKeys: []types.Hash32{hs[9], hs[10]}, - ContentValues: []any{mkFakeValue(hs[9]), mkFakeValue(hs[10])}, + &ItemBatchMessage{ + ContentKeys: []types.Hash32{hs[9], hs[10]}, }, - &decodedItemBatchMessage{ - ContentKeys: []types.Hash32{hs[11]}, - ContentValues: []any{mkFakeValue(hs[11])}, + &ItemBatchMessage{ + ContentKeys: []types.Hash32{hs[11]}, }, &EndRoundMessage{}, }, @@ -307,7 +297,6 @@ func TestWireConduit(t *testing.T) { client := newFakeRequester("client", nil, srv) var c wireConduit - c.newValue = func() any { return new(fakeValue) } initReq, err := c.withInitialRequest(func(c Conduit) error { if err := c.SendFingerprint(hs[0], hs[1], fp, 4); err != nil { return err @@ -315,7 +304,7 @@ func TestWireConduit(t *testing.T) { return c.SendEndRound() }) require.NoError(t, err) - clientCbk := makeTestRequestCallback(t, &c, c.newValue, []fakeRound{ + clientCbk := makeTestRequestCallback(t, &c, []fakeRound{ { name: "client got 1st response", expectMsgs: []SyncMessage{ @@ -329,13 +318,11 @@ func TestWireConduit(t *testing.T) { RangeY: hs[6], NumItems: 2, }, - &decodedItemBatchMessage{ - ContentKeys: []types.Hash32{hs[4], hs[5]}, - ContentValues: []any{mkFakeValue(hs[4]), mkFakeValue(hs[5])}, + &ItemBatchMessage{ + ContentKeys: []types.Hash32{hs[4], hs[5]}, }, - &decodedItemBatchMessage{ - ContentKeys: []types.Hash32{hs[7], hs[8]}, - ContentValues: []any{mkFakeValue(hs[7]), mkFakeValue(hs[8])}, + &ItemBatchMessage{ + ContentKeys: []types.Hash32{hs[7], hs[8]}, }, &EndRoundMessage{}, }, @@ -423,8 +410,8 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { 
// numTestHashes: 5000000, // minNumSpecificA: 15000, // maxNumSpecificA: 20000, - // minNumSpecificB: 15000, - // maxNumSpecificB: 20000, + // minNumSpecificB: 15, + // maxNumSpecificB: 20, maxSendRange: 1, numTestHashes: 100000, minNumSpecificA: 4, @@ -454,9 +441,9 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { } func TestWireSync(t *testing.T) { - // t.Run("fake requester", func(t *testing.T) { - // testWireSync(t, fakeRequesterGetter()) - // }) + t.Run("fake requester", func(t *testing.T) { + testWireSync(t, fakeRequesterGetter()) + }) t.Run("p2p", func(t *testing.T) { testWireSync(t, p2pRequesterGetter(t)) }) diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index 28f624a2e0..e3d3d0b292 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -56,7 +56,6 @@ type SyncMessage interface { Fingerprint() any Count() int Keys() []Ordered - Values() []any } func SyncMessageToString(m SyncMessage) string { @@ -75,16 +74,13 @@ func SyncMessageToString(m SyncMessage) string { if fp := m.Fingerprint(); fp != nil { sb.WriteString(" FP=" + fp.(fmt.Stringer).String()) } - vals := m.Values() - for n, k := range m.Keys() { - fmt.Fprintf(&sb, " item=[%s:%#v]", k.(fmt.Stringer).String(), vals[n]) + for _, k := range m.Keys() { + fmt.Fprintf(&sb, " item=%s", k.(fmt.Stringer).String()) } sb.WriteString(">") return sb.String() } -type NewValueFunc func() any - // Conduit handles receiving and sending peer messages // TODO: replace multiple Send* methods with a single one // (after de-generalizing messages) @@ -103,9 +99,9 @@ type Conduit interface { // SendEmptyRange notifies the peer that the specified range // is empty on our side. The corresponding SyncMessage has Count() == 0 SendEmptyRange(x, y Ordered) error - // SendItems notifies the peer that the corresponding range items will + // SendRangeContents notifies the peer that the corresponding range items will // be included in this sync round. 
The items themselves are sent via - // SendItemsOnly + // SendItems SendRangeContents(x, y Ordered, count int) error // SendItems sends just items without any message SendItems(count, chunkSize int, it Iterator) error @@ -153,9 +149,6 @@ type Iterator interface { // Key returns the key corresponding to iterator position. It returns // nil if the ItemStore is empty Key() Ordered - // Value returns the value corresponding to the iterator. It returns nil - // if the ItemStore is empty - Value() any // Next advances the iterator Next() } @@ -167,8 +160,8 @@ type RangeInfo struct { } type ItemStore interface { - // Add adds a key-value pair to the store - Add(ctx context.Context, k Ordered, v any) error + // Add adds a key to the store + Add(ctx context.Context, k Ordered) error // GetRangeInfo returns RangeInfo for the item range in the tree. // If count >= 0, at most count items are returned, and RangeInfo // is returned for the corresponding subrange of the requested range. @@ -181,8 +174,6 @@ type ItemStore interface { // Max returns the iterator pointing at the maximum element // in the store. 
If the store is empty, it returns nil Max() Iterator - // New returns an empty payload value - New() any // Copy makes a shallow copy of the ItemStore Copy() ItemStore // Has returns true if the specified key is present in ItemStore @@ -561,9 +552,8 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo done = true for _, msg := range msgs { if msg.Type() == MessageTypeItemBatch { - vals := msg.Values() - for n, k := range msg.Keys() { - if err := rsr.is.Add(ctx, k, vals[n]); err != nil { + for _, k := range msg.Keys() { + if err := rsr.is.Add(ctx, k); err != nil { return false, fmt.Errorf("error adding an item to the store: %w", err) } } diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index bf2ba15163..2fccf5aa3a 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -13,12 +13,11 @@ import ( ) type rangeMessage struct { - mtype MessageType - x, y Ordered - fp any - count int - keys []Ordered - values []any + mtype MessageType + x, y Ordered + fp any + count int + keys []Ordered } var _ SyncMessage = rangeMessage{} @@ -29,7 +28,6 @@ func (m rangeMessage) Y() Ordered { return m.y } func (m rangeMessage) Fingerprint() any { return m.fp } func (m rangeMessage) Count() int { return m.count } func (m rangeMessage) Keys() []Ordered { return m.keys } -func (m rangeMessage) Values() []any { return m.values } func (m rangeMessage) String() string { return SyncMessageToString(m) @@ -125,7 +123,6 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { panic("fakeConduit.SendItems: went got to the end of the tree") } msg.keys = append(msg.keys, it.Key()) - msg.values = append(msg.values, it.Value()) it.Next() n-- } @@ -196,13 +193,6 @@ func (it *dumbStoreIterator) Key() Ordered { return it.ds.keys[it.n] } -func (it *dumbStoreIterator) Value() any { - if len(it.ds.keys) == 0 { - return nil - } - return it.ds.m[it.Key().(sampleID)] -} - func (it *dumbStoreIterator) Next() { if 
len(it.ds.keys) != 0 { it.n = (it.n + 1) % len(it.ds.keys) @@ -211,19 +201,14 @@ func (it *dumbStoreIterator) Next() { type dumbStore struct { keys []sampleID - m map[sampleID]any } var _ ItemStore = &dumbStore{} -func (ds *dumbStore) Add(ctx context.Context, k Ordered, v any) error { - if ds.m == nil { - ds.m = make(map[sampleID]any) - } +func (ds *dumbStore) Add(ctx context.Context, k Ordered) error { id := k.(sampleID) if len(ds.keys) == 0 { ds.keys = []sampleID{id} - ds.m[id] = v return nil } p := slices.IndexFunc(ds.keys, func(other sampleID) bool { @@ -232,12 +217,10 @@ func (ds *dumbStore) Add(ctx context.Context, k Ordered, v any) error { switch { case p < 0: ds.keys = append(ds.keys, id) - ds.m[id] = v case id == ds.keys[p]: // already present default: ds.keys = slices.Insert(ds.keys, p, id) - ds.m[id] = v } return nil @@ -320,10 +303,6 @@ func (ds *dumbStore) Max() Iterator { } } -func (it *dumbStore) New() any { - panic("not implemented") -} - func (ds *dumbStore) Copy() ItemStore { panic("not implemented") } @@ -363,13 +342,6 @@ func (it verifiedStoreIterator) Key() Ordered { return k2 } -func (it verifiedStoreIterator) Value() any { - v1 := it.knownGood.Value() - v2 := it.it.Value() - assert.Equal(it.t, v1, v2, "values") - return v2 -} - func (it verifiedStoreIterator) Next() { it.knownGood.Next() it.it.Next() @@ -392,7 +364,7 @@ func disableReAdd(s ItemStore) { } } -func (vs *verifiedStore) Add(ctx context.Context, k Ordered, v any) error { +func (vs *verifiedStore) Add(ctx context.Context, k Ordered) error { if vs.disableReAdd { _, found := vs.added[k.(sampleID)] require.False(vs.t, found, "hash sent twice: %v", k) @@ -401,10 +373,10 @@ func (vs *verifiedStore) Add(ctx context.Context, k Ordered, v any) error { } vs.added[k.(sampleID)] = struct{}{} } - if err := vs.knownGood.Add(ctx, k, v); err != nil { + if err := vs.knownGood.Add(ctx, k); err != nil { return fmt.Errorf("add to knownGood: %w", err) } - if err := vs.store.Add(ctx, k, v); err != nil { 
+ if err := vs.store.Add(ctx, k); err != nil { return fmt.Errorf("add to store: %w", err) } return nil @@ -491,13 +463,6 @@ func (vs *verifiedStore) Max() Iterator { } } -func (vs *verifiedStore) New() any { - v1 := vs.knownGood.New() - v2 := vs.store.New() - require.Equal(vs.t, v1, v2, "New") - return v2 -} - func (vs *verifiedStore) Copy() ItemStore { return &verifiedStore{ t: vs.t, @@ -521,10 +486,7 @@ func makeDumbStore(t *testing.T) ItemStore { } func makeSyncTreeStore(t *testing.T) ItemStore { - return NewSyncTreeStore(sampleMonoid{}, nil, func() any { - // newValue func is only called by wireConduit - panic("not implemented") - }) + return NewSyncTreeStore(sampleMonoid{}) } func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { @@ -538,7 +500,7 @@ func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { func makeStore(t *testing.T, f storeFactory, items string) ItemStore { s := f(t) for _, c := range items { - require.NoError(t, s.Add(context.Background(), sampleID(c), "")) + require.NoError(t, s.Add(context.Background(), sampleID(c))) } return s } diff --git a/hashsync/setsyncbase.go b/hashsync/setsyncbase.go index a088dd7b83..467ccadfc4 100644 --- a/hashsync/setsyncbase.go +++ b/hashsync/setsyncbase.go @@ -111,11 +111,11 @@ func (ss *setSyncer) sync(ctx context.Context, x, y *types.Hash32) error { } // Add implements ItemStore. 
-func (ss *setSyncer) Add(ctx context.Context, k Ordered, v any) error { +func (ss *setSyncer) Add(ctx context.Context, k Ordered) error { select { case <-ctx.Done(): return ctx.Err() case ss.keyCh <- k: } - return ss.ItemStore.Add(ctx, k, v) + return ss.ItemStore.Add(ctx, k) } diff --git a/hashsync/sync_tree_store.go b/hashsync/sync_tree_store.go index 852e446841..bfc949f362 100644 --- a/hashsync/sync_tree_store.go +++ b/hashsync/sync_tree_store.go @@ -2,25 +2,9 @@ package hashsync import "context" -type ValueHandler interface { - Load(k Ordered, treeValue any) (v any) - Store(k Ordered, v any) (treeValue any) -} - -type defaultValueHandler struct{} - -func (vh defaultValueHandler) Load(k Ordered, treeValue any) (v any) { - return treeValue -} - -func (vh defaultValueHandler) Store(k Ordered, v any) (treeValue any) { - return v -} - type syncTreeIterator struct { st SyncTree ptr SyncTreePointer - vh ValueHandler } var _ Iterator = &syncTreeIterator{} @@ -37,10 +21,6 @@ func (it *syncTreeIterator) Key() Ordered { return it.ptr.Key() } -func (it *syncTreeIterator) Value() any { - return it.vh.Load(it.ptr.Key(), it.ptr.Value()) -} - func (it *syncTreeIterator) Next() { it.ptr.Next() if it.ptr.Key() == nil { @@ -50,29 +30,21 @@ func (it *syncTreeIterator) Next() { type SyncTreeStore struct { st SyncTree - vh ValueHandler - newValue NewValueFunc identity any } var _ ItemStore = &SyncTreeStore{} -func NewSyncTreeStore(m Monoid, vh ValueHandler, newValue NewValueFunc) ItemStore { - if vh == nil { - vh = defaultValueHandler{} - } +func NewSyncTreeStore(m Monoid) ItemStore { return &SyncTreeStore{ st: NewSyncTree(CombineMonoids(m, CountingMonoid{})), - vh: vh, - newValue: newValue, identity: m.Identity(), } } // Add implements ItemStore. 
-func (sts *SyncTreeStore) Add(ctx context.Context, k Ordered, v any) error { - treeValue := sts.vh.Store(k, v) - sts.st.Set(k, treeValue) +func (sts *SyncTreeStore) Add(ctx context.Context, k Ordered) error { + sts.st.Set(k, nil) return nil } @@ -83,7 +55,6 @@ func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { return &syncTreeIterator{ st: sts.st, ptr: ptr, - vh: sts.vh, } } @@ -136,17 +107,10 @@ func (sts *SyncTreeStore) Max() Iterator { return sts.iter(sts.st.Max()) } -// New implements ItemStore. -func (sts *SyncTreeStore) New() any { - return sts.newValue() -} - // Copy implements ItemStore. func (sts *SyncTreeStore) Copy() ItemStore { return &SyncTreeStore{ st: sts.st.Copy(), - vh: sts.vh, - newValue: sts.newValue, identity: sts.identity, } } diff --git a/hashsync/wire_types.go b/hashsync/wire_types.go index 27cc1ba49e..bc94a6fa7d 100644 --- a/hashsync/wire_types.go +++ b/hashsync/wire_types.go @@ -17,7 +17,6 @@ func (*Marker) Y() Ordered { return nil } func (*Marker) Fingerprint() any { return nil } func (*Marker) Count() int { return 0 } func (*Marker) Keys() []Ordered { return nil } -func (*Marker) Values() []any { return nil } // DoneMessage is a SyncMessage that denotes the end of the synchronization. // The peer should stop any further processing after receiving this message. 
@@ -56,7 +55,6 @@ func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY } func (m *EmptyRangeMessage) Fingerprint() any { return nil } func (m *EmptyRangeMessage) Count() int { return 0 } func (m *EmptyRangeMessage) Keys() []Ordered { return nil } -func (m *EmptyRangeMessage) Values() []any { return nil } // FingerprintMessage contains range fingerprint for comparison against the // peer's fingerprint of the range with the same bounds [RangeX, RangeY) @@ -74,7 +72,6 @@ func (m *FingerprintMessage) Y() Ordered { return m.RangeY } func (m *FingerprintMessage) Fingerprint() any { return m.RangeFingerprint } func (m *FingerprintMessage) Count() int { return int(m.NumItems) } func (m *FingerprintMessage) Keys() []Ordered { return nil } -func (m *FingerprintMessage) Values() []any { return nil } // RangeContentsMessage denotes a range for which the set of items has been sent. // The peer needs to send back any items it has in the same range bounded @@ -92,18 +89,24 @@ func (m *RangeContentsMessage) Y() Ordered { return m.RangeY } func (m *RangeContentsMessage) Fingerprint() any { return nil } func (m *RangeContentsMessage) Count() int { return int(m.NumItems) } func (m *RangeContentsMessage) Keys() []Ordered { return nil } -func (m *RangeContentsMessage) Values() []any { return nil } // ItemBatchMessage denotes a batch of items to be added to the peer's set. 
-// ItemBatchMessage doesn't implement SyncMessage interface by itself -// and needs to be wrapped in TypedItemBatchMessage[T] that implements -// SyncMessage by providing the proper Values() method type ItemBatchMessage struct { - ContentKeys []types.Hash32 `scale:"max=1024"` - ContentValues []byte `scale:"max=1024"` + ContentKeys []types.Hash32 `scale:"max=1024"` } func (m *ItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } +func (m *ItemBatchMessage) X() Ordered { return nil } +func (m *ItemBatchMessage) Y() Ordered { return nil } +func (m *ItemBatchMessage) Fingerprint() any { return nil } +func (m *ItemBatchMessage) Count() int { return 0 } +func (m *ItemBatchMessage) Keys() []Ordered { + var r []Ordered + for _, k := range m.ContentKeys { + r = append(r, k) + } + return r +} // ProbeMessage requests bounded range fingerprint and count from the peer, // along with a minhash sample if fingerprints differ @@ -132,7 +135,6 @@ func (m *ProbeMessage) Y() Ordered { func (m *ProbeMessage) Fingerprint() any { return m.RangeFingerprint } func (m *ProbeMessage) Count() int { return int(m.SampleSize) } func (m *ProbeMessage) Keys() []Ordered { return nil } -func (m *ProbeMessage) Values() []any { return nil } // MinhashSampleItem represents an item of minhash sample subset type MinhashSampleItem uint32 @@ -194,7 +196,6 @@ func (m *ProbeResponseMessage) Y() Ordered { } func (m *ProbeResponseMessage) Fingerprint() any { return m.RangeFingerprint } func (m *ProbeResponseMessage) Count() int { return int(m.NumItems) } -func (m *ProbeResponseMessage) Values() []any { return nil } func (m *ProbeResponseMessage) Keys() []Ordered { r := make([]Ordered, len(m.Sample)) diff --git a/hashsync/wire_types_scale.go b/hashsync/wire_types_scale.go index 164c4f1cf5..4a0343e2fa 100644 --- a/hashsync/wire_types_scale.go +++ b/hashsync/wire_types_scale.go @@ -242,13 +242,6 @@ func (t *ItemBatchMessage) EncodeScale(enc *scale.Encoder) (total int, err error } total += n } - { 
- n, err := scale.EncodeByteSliceWithLimit(enc, t.ContentValues, 1024) - if err != nil { - return total, err - } - total += n - } return total, nil } @@ -261,14 +254,6 @@ func (t *ItemBatchMessage) DecodeScale(dec *scale.Decoder) (total int, err error total += n t.ContentKeys = field } - { - field, n, err := scale.DecodeByteSliceWithLimit(dec, 1024) - if err != nil { - return total, err - } - total += n - t.ContentValues = field - } return total, nil } diff --git a/hashsync/xorsync_test.go b/hashsync/xorsync_test.go index 2d312b9c50..c8008f8f6d 100644 --- a/hashsync/xorsync_test.go +++ b/hashsync/xorsync_test.go @@ -6,7 +6,6 @@ import ( "slices" "testing" - "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,22 +30,14 @@ func TestHash32To12Xor(t *testing.T) { require.Equal(t, m.Op(m.Op(fp1, fp2), fp3), m.Op(fp1, m.Op(fp2, fp3))) } -type pair[K any, V any] struct { - k K - v V -} - -func collectStoreItems[K Ordered, V any](is ItemStore) (r []pair[K, V]) { +func collectStoreItems[K Ordered](is ItemStore) (r []K) { it := is.Min() if it == nil { return nil } endAt := is.Min() for { - r = append(r, pair[K, V]{ - k: it.Key().(K), - v: it.Value().(V), - }) + r = append(r, it.Key().(K)) it.Next() if it.Equal(endAt) { return r @@ -60,11 +51,11 @@ type catchTransferTwice struct { added map[types.Hash32]bool } -func (s *catchTransferTwice) Add(ctx context.Context, k Ordered, v any) error { +func (s *catchTransferTwice) Add(ctx context.Context, k Ordered) error { h := k.(types.Hash32) _, found := s.added[h] assert.False(s.t, found, "hash sent twice") - if err := s.ItemStore.Add(ctx, k, v); err != nil { + if err := s.ItemStore.Add(ctx, k); err != nil { return err } if s.added == nil { @@ -83,29 +74,6 @@ type xorSyncTestConfig struct { maxNumSpecificB int } -type fakeValue struct { - v string -} - -var ( - _ scale.Decodable = &fakeValue{} - _ scale.Encodable = 
&fakeValue{} -) - -func mkFakeValue(h types.Hash32) *fakeValue { - return &fakeValue{v: h.String()} -} - -func (fv *fakeValue) DecodeScale(dec *scale.Decoder) (total int, err error) { - s, total, err := scale.DecodeString(dec) - fv.v = s - return total, err -} - -func (fv *fakeValue) EncodeScale(enc *scale.Encoder) (total int, err error) { - return scale.EncodeString(enc, fv.v) -} - func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool) { opts := []Option{ WithMaxSendRange(cfg.maxSendRange), @@ -118,17 +86,17 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB } sliceA := src[:cfg.numTestHashes-numSpecificB] - storeA := NewSyncTreeStore(Hash32To12Xor{}, nil, func() any { return new(fakeValue) }) + storeA := NewSyncTreeStore(Hash32To12Xor{}) for _, h := range sliceA { - require.NoError(t, storeA.Add(context.Background(), h, mkFakeValue(h))) + require.NoError(t, storeA.Add(context.Background(), h)) } storeA = &catchTransferTwice{t: t, ItemStore: storeA} sliceB := append([]types.Hash32(nil), src[:cfg.numTestHashes-numSpecificB-numSpecificA]...) sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) 
- storeB := NewSyncTreeStore(Hash32To12Xor{}, nil, func() any { return new(fakeValue) }) + storeB := NewSyncTreeStore(Hash32To12Xor{}) for _, h := range sliceB { - require.NoError(t, storeB.Add(context.Background(), h, mkFakeValue(h))) + require.NoError(t, storeB.Add(context.Background(), h)) } storeB = &catchTransferTwice{t: t, ItemStore: storeB} @@ -137,17 +105,14 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB }) if sync(storeA, storeB, numSpecificA+numSpecificB, opts) { - itemsA := collectStoreItems[types.Hash32, *fakeValue](storeA) - itemsB := collectStoreItems[types.Hash32, *fakeValue](storeB) + itemsA := collectStoreItems[types.Hash32](storeA) + itemsB := collectStoreItems[types.Hash32](storeB) require.Equal(t, itemsA, itemsB) - srcPairs := make([]pair[types.Hash32, *fakeValue], len(src)) + srcKeys := make([]types.Hash32, len(src)) for n, h := range src { - srcPairs[n] = pair[types.Hash32, *fakeValue]{ - k: h, - v: mkFakeValue(h), - } + srcKeys[n] = h } - require.Equal(t, srcPairs, itemsA) + require.Equal(t, srcKeys, itemsA) } } From fbe9ceb348807da49e56fe7349e648c43ebd2694 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 17 May 2024 04:52:10 +0400 Subject: [PATCH 20/76] hashsync: implement working setSyncBase / setSyncer --- hashsync/interface.go | 46 +++- hashsync/mocks_test.go | 510 ++++++++++++++++++++++++++++++++++- hashsync/multipeer.go | 97 +++---- hashsync/rangesync.go | 38 --- hashsync/rangesync_test.go | 2 +- hashsync/setsyncbase.go | 110 ++++---- hashsync/setsyncbase_test.go | 216 +++++++++++++++ 7 files changed, 856 insertions(+), 163 deletions(-) create mode 100644 hashsync/setsyncbase_test.go diff --git a/hashsync/interface.go b/hashsync/interface.go index 042d57a86d..6285010d47 100644 --- a/hashsync/interface.go +++ b/hashsync/interface.go @@ -10,6 +10,45 @@ import ( //go:generate mockgen -typed -package=hashsync -destination=./mocks_test.go -source=./interface.go +// Iterator points to in item in 
ItemStore +type Iterator interface { + // Equal returns true if this iterator is equal to another Iterator + Equal(other Iterator) bool + // Key returns the key corresponding to iterator position. It returns + // nil if the ItemStore is empty + Key() Ordered + // Next advances the iterator + Next() +} + +type RangeInfo struct { + Fingerprint any + Count int + Start, End Iterator +} + +// ItemStore represents the data store that can be synced against a remote peer +type ItemStore interface { + // Add adds a key to the store + Add(ctx context.Context, k Ordered) error + // GetRangeInfo returns RangeInfo for the item range in the tree. + // If count >= 0, at most count items are returned, and RangeInfo + // is returned for the corresponding subrange of the requested range. + // If both x and y is nil, the whole set of items is used. + // If only x or only y is nil, GetRangeInfo panics + GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo + // Min returns the iterator pointing at the minimum element + // in the store. If the store is empty, it returns nil + Min() Iterator + // Max returns the iterator pointing at the maximum element + // in the store. 
If the store is empty, it returns nil + Max() Iterator + // Copy makes a shallow copy of the ItemStore + Copy() ItemStore + // Has returns true if the specified key is present in ItemStore + Has(k Ordered) bool +} + type requester interface { Run(context.Context) error StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error @@ -19,7 +58,7 @@ type syncBase interface { count() int derive(p p2p.Peer) syncer probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) - run(ctx context.Context) error + wait() error } type syncer interface { @@ -31,3 +70,8 @@ type syncRunner interface { splitSync(ctx context.Context, syncPeers []p2p.Peer) error fullSync(ctx context.Context, syncPeers []p2p.Peer) error } + +type pairwiseSyncer interface { + probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) + syncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error +} diff --git a/hashsync/mocks_test.go b/hashsync/mocks_test.go index 7a0848ce0b..ec175b0b1d 100644 --- a/hashsync/mocks_test.go +++ b/hashsync/mocks_test.go @@ -19,6 +19,392 @@ import ( gomock "go.uber.org/mock/gomock" ) +// MockIterator is a mock of Iterator interface. +type MockIterator struct { + ctrl *gomock.Controller + recorder *MockIteratorMockRecorder +} + +// MockIteratorMockRecorder is the mock recorder for MockIterator. +type MockIteratorMockRecorder struct { + mock *MockIterator +} + +// NewMockIterator creates a new mock instance. +func NewMockIterator(ctrl *gomock.Controller) *MockIterator { + mock := &MockIterator{ctrl: ctrl} + mock.recorder = &MockIteratorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIterator) EXPECT() *MockIteratorMockRecorder { + return m.recorder +} + +// Equal mocks base method. 
+func (m *MockIterator) Equal(other Iterator) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Equal", other) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Equal indicates an expected call of Equal. +func (mr *MockIteratorMockRecorder) Equal(other any) *MockIteratorEqualCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Equal", reflect.TypeOf((*MockIterator)(nil).Equal), other) + return &MockIteratorEqualCall{Call: call} +} + +// MockIteratorEqualCall wrap *gomock.Call +type MockIteratorEqualCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockIteratorEqualCall) Return(arg0 bool) *MockIteratorEqualCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockIteratorEqualCall) Do(f func(Iterator) bool) *MockIteratorEqualCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockIteratorEqualCall) DoAndReturn(f func(Iterator) bool) *MockIteratorEqualCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Key mocks base method. +func (m *MockIterator) Key() Ordered { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Key") + ret0, _ := ret[0].(Ordered) + return ret0 +} + +// Key indicates an expected call of Key. 
+func (mr *MockIteratorMockRecorder) Key() *MockIteratorKeyCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockIterator)(nil).Key)) + return &MockIteratorKeyCall{Call: call} +} + +// MockIteratorKeyCall wrap *gomock.Call +type MockIteratorKeyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockIteratorKeyCall) Return(arg0 Ordered) *MockIteratorKeyCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockIteratorKeyCall) Do(f func() Ordered) *MockIteratorKeyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockIteratorKeyCall) DoAndReturn(f func() Ordered) *MockIteratorKeyCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Next mocks base method. +func (m *MockIterator) Next() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Next") +} + +// Next indicates an expected call of Next. +func (mr *MockIteratorMockRecorder) Next() *MockIteratorNextCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIterator)(nil).Next)) + return &MockIteratorNextCall{Call: call} +} + +// MockIteratorNextCall wrap *gomock.Call +type MockIteratorNextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockIteratorNextCall) Return() *MockIteratorNextCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockIteratorNextCall) Do(f func()) *MockIteratorNextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockIteratorNextCall) DoAndReturn(f func()) *MockIteratorNextCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockItemStore is a mock of ItemStore interface. 
+type MockItemStore struct { + ctrl *gomock.Controller + recorder *MockItemStoreMockRecorder +} + +// MockItemStoreMockRecorder is the mock recorder for MockItemStore. +type MockItemStoreMockRecorder struct { + mock *MockItemStore +} + +// NewMockItemStore creates a new mock instance. +func NewMockItemStore(ctrl *gomock.Controller) *MockItemStore { + mock := &MockItemStore{ctrl: ctrl} + mock.recorder = &MockItemStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockItemStore) EXPECT() *MockItemStoreMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockItemStore) Add(ctx context.Context, k Ordered) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", ctx, k) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add. +func (mr *MockItemStoreMockRecorder) Add(ctx, k any) *MockItemStoreAddCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockItemStore)(nil).Add), ctx, k) + return &MockItemStoreAddCall{Call: call} +} + +// MockItemStoreAddCall wrap *gomock.Call +type MockItemStoreAddCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreAddCall) Return(arg0 error) *MockItemStoreAddCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreAddCall) Do(f func(context.Context, Ordered) error) *MockItemStoreAddCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreAddCall) DoAndReturn(f func(context.Context, Ordered) error) *MockItemStoreAddCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Copy mocks base method. +func (m *MockItemStore) Copy() ItemStore { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Copy") + ret0, _ := ret[0].(ItemStore) + return ret0 +} + +// Copy indicates an expected call of Copy. 
+func (mr *MockItemStoreMockRecorder) Copy() *MockItemStoreCopyCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", reflect.TypeOf((*MockItemStore)(nil).Copy)) + return &MockItemStoreCopyCall{Call: call} +} + +// MockItemStoreCopyCall wrap *gomock.Call +type MockItemStoreCopyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreCopyCall) Return(arg0 ItemStore) *MockItemStoreCopyCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreCopyCall) Do(f func() ItemStore) *MockItemStoreCopyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreCopyCall) DoAndReturn(f func() ItemStore) *MockItemStoreCopyCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// GetRangeInfo mocks base method. +func (m *MockItemStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRangeInfo", preceding, x, y, count) + ret0, _ := ret[0].(RangeInfo) + return ret0 +} + +// GetRangeInfo indicates an expected call of GetRangeInfo. 
+func (mr *MockItemStoreMockRecorder) GetRangeInfo(preceding, x, y, count any) *MockItemStoreGetRangeInfoCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeInfo", reflect.TypeOf((*MockItemStore)(nil).GetRangeInfo), preceding, x, y, count) + return &MockItemStoreGetRangeInfoCall{Call: call} +} + +// MockItemStoreGetRangeInfoCall wrap *gomock.Call +type MockItemStoreGetRangeInfoCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreGetRangeInfoCall) Return(arg0 RangeInfo) *MockItemStoreGetRangeInfoCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreGetRangeInfoCall) Do(f func(Iterator, Ordered, Ordered, int) RangeInfo) *MockItemStoreGetRangeInfoCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreGetRangeInfoCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) RangeInfo) *MockItemStoreGetRangeInfoCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Has mocks base method. +func (m *MockItemStore) Has(k Ordered) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Has", k) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Has indicates an expected call of Has. 
+func (mr *MockItemStoreMockRecorder) Has(k any) *MockItemStoreHasCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockItemStore)(nil).Has), k) + return &MockItemStoreHasCall{Call: call} +} + +// MockItemStoreHasCall wrap *gomock.Call +type MockItemStoreHasCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreHasCall) Return(arg0 bool) *MockItemStoreHasCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreHasCall) Do(f func(Ordered) bool) *MockItemStoreHasCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreHasCall) DoAndReturn(f func(Ordered) bool) *MockItemStoreHasCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Max mocks base method. +func (m *MockItemStore) Max() Iterator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Max") + ret0, _ := ret[0].(Iterator) + return ret0 +} + +// Max indicates an expected call of Max. +func (mr *MockItemStoreMockRecorder) Max() *MockItemStoreMaxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Max", reflect.TypeOf((*MockItemStore)(nil).Max)) + return &MockItemStoreMaxCall{Call: call} +} + +// MockItemStoreMaxCall wrap *gomock.Call +type MockItemStoreMaxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreMaxCall) Return(arg0 Iterator) *MockItemStoreMaxCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreMaxCall) Do(f func() Iterator) *MockItemStoreMaxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreMaxCall) DoAndReturn(f func() Iterator) *MockItemStoreMaxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Min mocks base method. 
+func (m *MockItemStore) Min() Iterator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Min") + ret0, _ := ret[0].(Iterator) + return ret0 +} + +// Min indicates an expected call of Min. +func (mr *MockItemStoreMockRecorder) Min() *MockItemStoreMinCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Min", reflect.TypeOf((*MockItemStore)(nil).Min)) + return &MockItemStoreMinCall{Call: call} +} + +// MockItemStoreMinCall wrap *gomock.Call +type MockItemStoreMinCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreMinCall) Return(arg0 Iterator) *MockItemStoreMinCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreMinCall) Do(f func() Iterator) *MockItemStoreMinCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreMinCall) DoAndReturn(f func() Iterator) *MockItemStoreMinCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Mockrequester is a mock of requester interface. type Mockrequester struct { ctrl *gomock.Controller @@ -261,40 +647,40 @@ func (c *MocksyncBaseprobeCall) DoAndReturn(f func(context.Context, p2p.Peer) (P return c } -// run mocks base method. -func (m *MocksyncBase) run(ctx context.Context) error { +// wait mocks base method. +func (m *MocksyncBase) wait() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "run", ctx) + ret := m.ctrl.Call(m, "wait") ret0, _ := ret[0].(error) return ret0 } -// run indicates an expected call of run. -func (mr *MocksyncBaseMockRecorder) run(ctx any) *MocksyncBaserunCall { +// wait indicates an expected call of wait. 
+func (mr *MocksyncBaseMockRecorder) wait() *MocksyncBasewaitCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MocksyncBase)(nil).run), ctx) - return &MocksyncBaserunCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "wait", reflect.TypeOf((*MocksyncBase)(nil).wait)) + return &MocksyncBasewaitCall{Call: call} } -// MocksyncBaserunCall wrap *gomock.Call -type MocksyncBaserunCall struct { +// MocksyncBasewaitCall wrap *gomock.Call +type MocksyncBasewaitCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBaserunCall) Return(arg0 error) *MocksyncBaserunCall { +func (c *MocksyncBasewaitCall) Return(arg0 error) *MocksyncBasewaitCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBaserunCall) Do(f func(context.Context) error) *MocksyncBaserunCall { +func (c *MocksyncBasewaitCall) Do(f func() error) *MocksyncBasewaitCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncBaserunCall) DoAndReturn(f func(context.Context) error) *MocksyncBaserunCall { +func (c *MocksyncBasewaitCall) DoAndReturn(f func() error) *MocksyncBasewaitCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -496,3 +882,103 @@ func (c *MocksyncRunnersplitSyncCall) DoAndReturn(f func(context.Context, []p2p. c.Call = c.Call.DoAndReturn(f) return c } + +// MockpairwiseSyncer is a mock of pairwiseSyncer interface. +type MockpairwiseSyncer struct { + ctrl *gomock.Controller + recorder *MockpairwiseSyncerMockRecorder +} + +// MockpairwiseSyncerMockRecorder is the mock recorder for MockpairwiseSyncer. +type MockpairwiseSyncerMockRecorder struct { + mock *MockpairwiseSyncer +} + +// NewMockpairwiseSyncer creates a new mock instance. 
+func NewMockpairwiseSyncer(ctrl *gomock.Controller) *MockpairwiseSyncer { + mock := &MockpairwiseSyncer{ctrl: ctrl} + mock.recorder = &MockpairwiseSyncerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockpairwiseSyncer) EXPECT() *MockpairwiseSyncerMockRecorder { + return m.recorder +} + +// probe mocks base method. +func (m *MockpairwiseSyncer) probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "probe", ctx, peer, is, x, y) + ret0, _ := ret[0].(ProbeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// probe indicates an expected call of probe. +func (mr *MockpairwiseSyncerMockRecorder) probe(ctx, peer, is, x, y any) *MockpairwiseSyncerprobeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "probe", reflect.TypeOf((*MockpairwiseSyncer)(nil).probe), ctx, peer, is, x, y) + return &MockpairwiseSyncerprobeCall{Call: call} +} + +// MockpairwiseSyncerprobeCall wrap *gomock.Call +type MockpairwiseSyncerprobeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpairwiseSyncerprobeCall) Return(arg0 ProbeResult, arg1 error) *MockpairwiseSyncerprobeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpairwiseSyncerprobeCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockpairwiseSyncerprobeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpairwiseSyncerprobeCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockpairwiseSyncerprobeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// syncStore mocks base method. 
+func (m *MockpairwiseSyncer) syncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "syncStore", ctx, peer, is, x, y) + ret0, _ := ret[0].(error) + return ret0 +} + +// syncStore indicates an expected call of syncStore. +func (mr *MockpairwiseSyncerMockRecorder) syncStore(ctx, peer, is, x, y any) *MockpairwiseSyncersyncStoreCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "syncStore", reflect.TypeOf((*MockpairwiseSyncer)(nil).syncStore), ctx, peer, is, x, y) + return &MockpairwiseSyncersyncStoreCall{Call: call} +} + +// MockpairwiseSyncersyncStoreCall wrap *gomock.Call +type MockpairwiseSyncersyncStoreCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockpairwiseSyncersyncStoreCall) Return(arg0 error) *MockpairwiseSyncersyncStoreCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockpairwiseSyncersyncStoreCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockpairwiseSyncersyncStoreCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockpairwiseSyncersyncStoreCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockpairwiseSyncersyncStoreCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/hashsync/multipeer.go b/hashsync/multipeer.go index e3d29cfb07..a5dbe72aeb 100644 --- a/hashsync/multipeer.go +++ b/hashsync/multipeer.go @@ -195,6 +195,8 @@ func (mpr *MultiPeerReconciler) fullSync(ctx context.Context, syncPeers []p2p.Pe case errors.Is(err, context.Canceled): return err default: + // failing to sync against a particular peer is not considered + // a fatal sync failure, so we just log the error mpr.logger.Error("error syncing peer", zap.Stringer("peer", p), zap.Error(err)) } return nil @@ -203,6 +205,38 @@ func (mpr 
*MultiPeerReconciler) fullSync(ctx context.Context, syncPeers []p2p.Pe return eg.Wait() } +func (mpr *MultiPeerReconciler) syncOnce(ctx context.Context) error { + var ( + s syncability + err error + ) + for { + syncPeers := mpr.peers.SelectBest(mpr.syncPeerCount) + if len(syncPeers) != 0 { + // probePeers doesn't return transient errors, sync must stop if it failed + s, err = mpr.probePeers(ctx, syncPeers) + if err != nil { + return err + } + if len(s.syncable) != 0 { + break + } + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-mpr.clock.After(mpr.noPeersRecheckInterval): + } + } + + if mpr.needSplitSync(s) { + return mpr.runner.splitSync(ctx, s.splitSyncable) + } else { + return mpr.runner.fullSync(ctx, s.splitSyncable) + } +} + func (mpr *MultiPeerReconciler) Run(ctx context.Context) error { // The point of using split sync, which syncs different key ranges against // different peers, vs full sync which syncs the full key range against different @@ -234,67 +268,20 @@ func (mpr *MultiPeerReconciler) Run(ctx context.Context) error { // Wait for all the syncs to complete/fail // All syncs completed (success / fail) => A ctx, cancel := context.WithCancel(ctx) - var eg errgroup.Group - eg.Go(func() error { - err := mpr.syncBase.run(ctx) - if err != nil && !errors.Is(err, context.Canceled) { - cancel() - mpr.logger.Error("error processing synced items", zap.Error(err)) - return err - } - return nil - }) - defer func() { - cancel() - eg.Wait() - }() - + var err error +LOOP: for { select { case <-ctx.Done(): - cancel() - // if the key handlers have caused an error, return that error - return eg.Wait() + err = ctx.Err() + break LOOP case <-mpr.clock.After(mpr.syncInterval): } - var ( - s syncability - err error - ) - for { - syncPeers := mpr.peers.SelectBest(mpr.syncPeerCount) - if len(syncPeers) != 0 { - // probePeers doesn't return transient errors, sync must stop if it failed - s, err = mpr.probePeers(ctx, syncPeers) - if err != nil { - return err - } 
- if len(s.syncable) != 0 { - break - } - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-mpr.clock.After(mpr.noPeersRecheckInterval): - } - } - - if mpr.needSplitSync(s) { - err = mpr.runner.splitSync(ctx, s.splitSyncable) - } else { - err = mpr.runner.fullSync(ctx, s.splitSyncable) - } - - if err != nil { - return err + if err = mpr.syncOnce(ctx); err != nil { + break } } -} - -type HashSyncBase struct { - r requester - is ItemStore + cancel() + return errors.Join(err, mpr.syncBase.wait()) } diff --git a/hashsync/rangesync.go b/hashsync/rangesync.go index e3d3d0b292..7dcf2d197b 100644 --- a/hashsync/rangesync.go +++ b/hashsync/rangesync.go @@ -142,44 +142,6 @@ func WithSampleSize(s int) Option { } } -// Iterator points to in item in ItemStore -type Iterator interface { - // Equal returns true if this iterator is equal to another Iterator - Equal(other Iterator) bool - // Key returns the key corresponding to iterator position. It returns - // nil if the ItemStore is empty - Key() Ordered - // Next advances the iterator - Next() -} - -type RangeInfo struct { - Fingerprint any - Count int - Start, End Iterator -} - -type ItemStore interface { - // Add adds a key to the store - Add(ctx context.Context, k Ordered) error - // GetRangeInfo returns RangeInfo for the item range in the tree. - // If count >= 0, at most count items are returned, and RangeInfo - // is returned for the corresponding subrange of the requested range. - // If both x and y is nil, the whole set of items is used. - // If only x or only y is nil, GetRangeInfo panics - GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo - // Min returns the iterator pointing at the minimum element - // in the store. If the store is empty, it returns nil - Min() Iterator - // Max returns the iterator pointing at the maximum element - // in the store. 
If the store is empty, it returns nil - Max() Iterator - // Copy makes a shallow copy of the ItemStore - Copy() ItemStore - // Has returns true if the specified key is present in ItemStore - Has(k Ordered) bool -} - type ProbeResult struct { FP any Count int diff --git a/hashsync/rangesync_test.go b/hashsync/rangesync_test.go index 2fccf5aa3a..deeac28c79 100644 --- a/hashsync/rangesync_test.go +++ b/hashsync/rangesync_test.go @@ -304,7 +304,7 @@ func (ds *dumbStore) Max() Iterator { } func (ds *dumbStore) Copy() ItemStore { - panic("not implemented") + return &dumbStore{keys: slices.Clone(ds.keys)} } func (ds *dumbStore) Has(k Ordered) bool { diff --git a/hashsync/setsyncbase.go b/hashsync/setsyncbase.go index 467ccadfc4..b322f1002c 100644 --- a/hashsync/setsyncbase.go +++ b/hashsync/setsyncbase.go @@ -2,31 +2,32 @@ package hashsync import ( "context" + "errors" + "fmt" + + "github.com/spacemeshos/go-spacemesh/p2p" + "golang.org/x/sync/singleflight" - "github.com/libp2p/go-libp2p/core/peer" "github.com/spacemeshos/go-spacemesh/common/types" - "golang.org/x/sync/errgroup" ) type syncKeyHandler func(ctx context.Context, k Ordered) error type setSyncBase struct { - r requester + ps pairwiseSyncer is ItemStore handler syncKeyHandler - opts []Option - keyCh chan Ordered + waiting []<-chan singleflight.Result + g singleflight.Group } var _ syncBase = &setSyncBase{} -func newSetSyncBase(r requester, is ItemStore, handler syncKeyHandler, opts ...Option) *setSyncBase { +func newSetSyncBase(ps pairwiseSyncer, is ItemStore, handler syncKeyHandler) *setSyncBase { return &setSyncBase{ - r: r, + ps: ps, is: is, handler: handler, - opts: opts, - keyCh: make(chan Ordered), } } @@ -41,58 +42,44 @@ func (ssb *setSyncBase) count() int { } // derive implements syncBase. 
-func (ssb *setSyncBase) derive(p peer.ID) syncer { +func (ssb *setSyncBase) derive(p p2p.Peer) syncer { return &setSyncer{ - ItemStore: ssb.is.Copy(), - r: ssb.r, - opts: ssb.opts, - p: p, - keyCh: ssb.keyCh, + setSyncBase: ssb, + ItemStore: ssb.is.Copy(), + p: p, } } // probe implements syncBase. -func (ssb *setSyncBase) probe(ctx context.Context, p peer.ID) (ProbeResult, error) { - return Probe(ctx, ssb.r, p, ssb.is, nil, nil, ssb.opts...) +func (ssb *setSyncBase) probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { + return ssb.ps.probe(ctx, p, ssb.is, nil, nil) +} + +func (ssb *setSyncBase) acceptKey(ctx context.Context, k Ordered) { + key := k.(fmt.Stringer).String() + if !ssb.is.Has(k) { + ssb.waiting = append(ssb.waiting, + ssb.g.DoChan(key, func() (any, error) { + return key, ssb.handler(ctx, k) + })) + } } -// run implements syncBase. -func (ssb *setSyncBase) run(ctx context.Context) error { - eg, ctx := errgroup.WithContext(ctx) - doneCh := make(chan Ordered) - beingProcessed := make(map[Ordered]struct{}) - for { - select { - case <-ctx.Done(): - return eg.Wait() - case k := <-ssb.keyCh: - if ssb.is.Has(k) { - continue - } - if _, found := beingProcessed[k]; found { - continue - } - eg.Go(func() error { - defer func() { - select { - case <-ctx.Done(): - case doneCh <- k: - } - }() - return ssb.handler(ctx, k) - }) - case k := <-doneCh: - delete(beingProcessed, k) - } +func (ssb *setSyncBase) wait() error { + var errs []error + for _, w := range ssb.waiting { + r := <-w + ssb.g.Forget(r.Val.(string)) + errs = append(errs, r.Err) } + ssb.waiting = nil + return errors.Join(errs...) } type setSyncer struct { + *setSyncBase ItemStore - r requester - opts []Option - p peer.ID - keyCh chan<- Ordered + p p2p.Peer } var ( @@ -101,21 +88,32 @@ var ( ) // peer implements syncer. -func (ss *setSyncer) peer() peer.ID { +func (ss *setSyncer) peer() p2p.Peer { return ss.p } // sync implements syncer. 
func (ss *setSyncer) sync(ctx context.Context, x, y *types.Hash32) error { - return SyncStore(ctx, ss.r, ss.p, ss, x, y, ss.opts...) + return ss.ps.syncStore(ctx, ss.p, ss, x, y) } // Add implements ItemStore. func (ss *setSyncer) Add(ctx context.Context, k Ordered) error { - select { - case <-ctx.Done(): - return ctx.Err() - case ss.keyCh <- k: - } + ss.acceptKey(ctx, k) return ss.ItemStore.Add(ctx, k) } + +type realPairwiseSyncer struct { + r requester + opts []Option +} + +var _ pairwiseSyncer = &realPairwiseSyncer{} + +func (ps *realPairwiseSyncer) probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) { + return Probe(ctx, ps.r, peer, is, x, y, ps.opts...) +} + +func (ps *realPairwiseSyncer) syncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error { + return SyncStore(ctx, ps.r, peer, is, x, y, ps.opts...) +} diff --git a/hashsync/setsyncbase_test.go b/hashsync/setsyncbase_test.go new file mode 100644 index 0000000000..e7e60501c8 --- /dev/null +++ b/hashsync/setsyncbase_test.go @@ -0,0 +1,216 @@ +package hashsync + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + "golang.org/x/sync/errgroup" +) + +type setSyncBaseTester struct { + t *testing.T + ctrl *gomock.Controller + ps *MockpairwiseSyncer + is *MockItemStore + ssb *setSyncBase + waitMtx sync.Mutex + waitChs map[Ordered]chan error + doneCh chan Ordered +} + +func newSetSyncBaseTester(t *testing.T) *setSyncBaseTester { + ctrl := gomock.NewController(t) + st := &setSyncBaseTester{ + t: t, + ctrl: ctrl, + ps: NewMockpairwiseSyncer(ctrl), + is: NewMockItemStore(ctrl), + waitChs: make(map[Ordered]chan error), + doneCh: make(chan Ordered), + } + st.ssb = newSetSyncBase(st.ps, st.is, func(ctx context.Context, k Ordered) error { + err := <-st.getWaitCh(k) + st.doneCh 
<- k + return err + }) + return st +} + +func (st *setSyncBaseTester) getWaitCh(k Ordered) chan error { + st.waitMtx.Lock() + defer st.waitMtx.Unlock() + ch, found := st.waitChs[k] + if !found { + ch = make(chan error) + st.waitChs[k] = ch + } + return ch +} + +func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types.Hash32) { + st.is.EXPECT().Copy().DoAndReturn(func() ItemStore { + copy := NewMockItemStore(st.ctrl) + for _, k := range addedKeys { + copy.EXPECT().Add(ctx, k) + } + return copy + }) +} + +func (st *setSyncBaseTester) expectSyncStore( + ctx context.Context, + p p2p.Peer, + ss syncer, + addedKeys ...types.Hash32, +) { + st.ps.EXPECT().syncStore(ctx, p, ss, nil, nil). + DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { + for _, k := range addedKeys { + require.NoError(st.t, is.Add(ctx, k)) + } + return nil + }) +} + +func (st *setSyncBaseTester) failToSyncStore( + ctx context.Context, + p p2p.Peer, + ss syncer, + err error, +) { + st.ps.EXPECT().syncStore(ctx, p, ss, nil, nil). 
+ DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { + return err + }) +} + +func (st *setSyncBaseTester) wait(count int) ([]types.Hash32, error) { + var eg errgroup.Group + eg.Go(st.ssb.wait) + var handledKeys []types.Hash32 + for k := range st.doneCh { + handledKeys = append(handledKeys, k.(types.Hash32)) + count-- + if count == 0 { + break + } + } + return handledKeys, eg.Wait() +} + +func TestSetSyncBase(t *testing.T) { + t.Run("probe", func(t *testing.T) { + t.Parallel() + st := newSetSyncBaseTester(t) + ctx := context.Background() + expPr := ProbeResult{ + FP: types.RandomHash(), + Count: 42, + Sim: 0.99, + } + st.ps.EXPECT().probe(ctx, p2p.Peer("p1"), st.is, nil, nil).Return(expPr, nil) + pr, err := st.ssb.probe(ctx, p2p.Peer("p1")) + require.NoError(t, err) + require.Equal(t, expPr, pr) + }) + + t.Run("single key one-time sync", func(t *testing.T) { + t.Parallel() + st := newSetSyncBaseTester(t) + ctx := context.Background() + + addedKey := types.RandomHash() + st.expectCopy(ctx, addedKey) + ss := st.ssb.derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.peer()) + + x := types.RandomHash() + y := types.RandomHash() + st.ps.EXPECT().syncStore(ctx, p2p.Peer("p1"), ss, &x, &y) + require.NoError(t, ss.sync(ctx, &x, &y)) + + st.is.EXPECT().Has(addedKey) + st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) + require.NoError(t, ss.sync(ctx, nil, nil)) + close(st.getWaitCh(addedKey)) + + handledKeys, err := st.wait(1) + require.NoError(t, err) + require.ElementsMatch(t, []types.Hash32{addedKey}, handledKeys) + }) + + t.Run("single key synced multiple times", func(t *testing.T) { + t.Parallel() + st := newSetSyncBaseTester(t) + ctx := context.Background() + + addedKey := types.RandomHash() + st.expectCopy(ctx, addedKey, addedKey, addedKey) + ss := st.ssb.derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.peer()) + + for i := 0; i < 3; i++ { + st.is.EXPECT().Has(addedKey) + st.expectSyncStore(ctx, 
p2p.Peer("p1"), ss, addedKey) + require.NoError(t, ss.sync(ctx, nil, nil)) + } + close(st.getWaitCh(addedKey)) + + handledKeys, err := st.wait(1) + require.NoError(t, err) + require.ElementsMatch(t, []types.Hash32{addedKey}, handledKeys) + }) + + t.Run("multiple keys", func(t *testing.T) { + t.Parallel() + st := newSetSyncBaseTester(t) + ctx := context.Background() + + k1 := types.RandomHash() + k2 := types.RandomHash() + st.expectCopy(ctx, k1, k2) + ss := st.ssb.derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.peer()) + + st.is.EXPECT().Has(k1) + st.is.EXPECT().Has(k2) + st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) + require.NoError(t, ss.sync(ctx, nil, nil)) + close(st.getWaitCh(k1)) + close(st.getWaitCh(k2)) + + handledKeys, err := st.wait(2) + require.NoError(t, err) + require.ElementsMatch(t, []types.Hash32{k1, k2}, handledKeys) + }) + + t.Run("handler failure", func(t *testing.T) { + t.Parallel() + st := newSetSyncBaseTester(t) + ctx := context.Background() + + k1 := types.RandomHash() + k2 := types.RandomHash() + st.expectCopy(ctx, k1, k2) + ss := st.ssb.derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.peer()) + + st.is.EXPECT().Has(k1) + st.is.EXPECT().Has(k2) + st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) + require.NoError(t, ss.sync(ctx, nil, nil)) + handlerErr := errors.New("fail") + st.getWaitCh(k1) <- handlerErr + close(st.getWaitCh(k2)) + + handledKeys, err := st.wait(2) + require.ErrorIs(t, err, handlerErr) + require.ElementsMatch(t, []types.Hash32{k1, k2}, handledKeys) + }) +} From ed7ca2dbbaa575c451abe304bbd2bd397b0c5070 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 18 May 2024 09:54:17 +0400 Subject: [PATCH 21/76] hashsync: fix propagating keys to setSyncBase --- hashsync/setsyncbase.go | 6 +++- hashsync/setsyncbase_test.go | 59 +++++++++++++++++++++++++++++------- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/hashsync/setsyncbase.go b/hashsync/setsyncbase.go index 
b322f1002c..025be7a7ce 100644 --- a/hashsync/setsyncbase.go +++ b/hashsync/setsyncbase.go @@ -60,7 +60,11 @@ func (ssb *setSyncBase) acceptKey(ctx context.Context, k Ordered) { if !ssb.is.Has(k) { ssb.waiting = append(ssb.waiting, ssb.g.DoChan(key, func() (any, error) { - return key, ssb.handler(ctx, k) + err := ssb.handler(ctx, k) + if err == nil { + err = ssb.is.Add(ctx, k) + } + return key, err })) } } diff --git a/hashsync/setsyncbase_test.go b/hashsync/setsyncbase_test.go index e7e60501c8..618b222e95 100644 --- a/hashsync/setsyncbase_test.go +++ b/hashsync/setsyncbase_test.go @@ -14,7 +14,7 @@ import ( ) type setSyncBaseTester struct { - t *testing.T + *testing.T ctrl *gomock.Controller ps *MockpairwiseSyncer is *MockItemStore @@ -24,17 +24,20 @@ type setSyncBaseTester struct { doneCh chan Ordered } -func newSetSyncBaseTester(t *testing.T) *setSyncBaseTester { +func newSetSyncBaseTester(t *testing.T, is ItemStore) *setSyncBaseTester { ctrl := gomock.NewController(t) st := &setSyncBaseTester{ - t: t, + T: t, ctrl: ctrl, ps: NewMockpairwiseSyncer(ctrl), - is: NewMockItemStore(ctrl), waitChs: make(map[Ordered]chan error), doneCh: make(chan Ordered), } - st.ssb = newSetSyncBase(st.ps, st.is, func(ctx context.Context, k Ordered) error { + if is == nil { + st.is = NewMockItemStore(ctrl) + is = st.is + } + st.ssb = newSetSyncBase(st.ps, is, func(ctx context.Context, k Ordered) error { err := <-st.getWaitCh(k) st.doneCh <- k return err @@ -72,7 +75,7 @@ func (st *setSyncBaseTester) expectSyncStore( st.ps.EXPECT().syncStore(ctx, p, ss, nil, nil). 
DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { for _, k := range addedKeys { - require.NoError(st.t, is.Add(ctx, k)) + require.NoError(st, is.Add(ctx, k)) } return nil }) @@ -107,7 +110,7 @@ func (st *setSyncBaseTester) wait(count int) ([]types.Hash32, error) { func TestSetSyncBase(t *testing.T) { t.Run("probe", func(t *testing.T) { t.Parallel() - st := newSetSyncBaseTester(t) + st := newSetSyncBaseTester(t, nil) ctx := context.Background() expPr := ProbeResult{ FP: types.RandomHash(), @@ -122,7 +125,7 @@ func TestSetSyncBase(t *testing.T) { t.Run("single key one-time sync", func(t *testing.T) { t.Parallel() - st := newSetSyncBaseTester(t) + st := newSetSyncBaseTester(t, nil) ctx := context.Background() addedKey := types.RandomHash() @@ -136,6 +139,7 @@ func TestSetSyncBase(t *testing.T) { require.NoError(t, ss.sync(ctx, &x, &y)) st.is.EXPECT().Has(addedKey) + st.is.EXPECT().Add(ctx, addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.sync(ctx, nil, nil)) close(st.getWaitCh(addedKey)) @@ -147,7 +151,7 @@ func TestSetSyncBase(t *testing.T) { t.Run("single key synced multiple times", func(t *testing.T) { t.Parallel() - st := newSetSyncBaseTester(t) + st := newSetSyncBaseTester(t, nil) ctx := context.Background() addedKey := types.RandomHash() @@ -155,6 +159,8 @@ func TestSetSyncBase(t *testing.T) { ss := st.ssb.derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.peer()) + // added just once + st.is.EXPECT().Add(ctx, addedKey) for i := 0; i < 3; i++ { st.is.EXPECT().Has(addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) @@ -169,7 +175,7 @@ func TestSetSyncBase(t *testing.T) { t.Run("multiple keys", func(t *testing.T) { t.Parallel() - st := newSetSyncBaseTester(t) + st := newSetSyncBaseTester(t, nil) ctx := context.Background() k1 := types.RandomHash() @@ -180,6 +186,8 @@ func TestSetSyncBase(t *testing.T) { st.is.EXPECT().Has(k1) st.is.EXPECT().Has(k2) + 
st.is.EXPECT().Add(ctx, k1) + st.is.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) require.NoError(t, ss.sync(ctx, nil, nil)) close(st.getWaitCh(k1)) @@ -192,7 +200,7 @@ func TestSetSyncBase(t *testing.T) { t.Run("handler failure", func(t *testing.T) { t.Parallel() - st := newSetSyncBaseTester(t) + st := newSetSyncBaseTester(t, nil) ctx := context.Background() k1 := types.RandomHash() @@ -203,6 +211,8 @@ func TestSetSyncBase(t *testing.T) { st.is.EXPECT().Has(k1) st.is.EXPECT().Has(k2) + // k1 is not propagated to syncBase due to the handler failure + st.is.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) require.NoError(t, ss.sync(ctx, nil, nil)) handlerErr := errors.New("fail") @@ -213,4 +223,31 @@ func TestSetSyncBase(t *testing.T) { require.ErrorIs(t, err, handlerErr) require.ElementsMatch(t, []types.Hash32{k1, k2}, handledKeys) }) + + t.Run("synctree based item store", func(t *testing.T) { + t.Parallel() + hs := make([]types.Hash32, 4) + for n := range hs { + hs[n] = types.RandomHash() + } + is := NewSyncTreeStore(Hash32To12Xor{}) + is.Add(context.Background(), hs[0]) + is.Add(context.Background(), hs[1]) + st := newSetSyncBaseTester(t, is) + ss := st.ssb.derive(p2p.Peer("p1")) + ss.(ItemStore).Add(context.Background(), hs[2]) + ss.(ItemStore).Add(context.Background(), hs[3]) + // syncer's cloned ItemStore has new key immediately + require.True(t, ss.(ItemStore).Has(hs[2])) + require.True(t, ss.(ItemStore).Has(hs[3])) + handlerErr := errors.New("fail") + st.getWaitCh(hs[2]) <- handlerErr + close(st.getWaitCh(hs[3])) + handledKeys, err := st.wait(2) + require.ErrorIs(t, err, handlerErr) + require.ElementsMatch(t, hs[2:], handledKeys) + // only successfully handled key propagate the syncBase + require.False(t, is.Has(hs[2])) + require.True(t, is.Has(hs[3])) + }) } From 031620b156d07a0cc7d72098846db3921941a2ef Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 18 May 2024 09:54:47 +0400 Subject: [PATCH 
22/76] hashsync: fix multipeer and add test --- hashsync/multipeer.go | 48 ++++++- hashsync/multipeer_test.go | 261 +++++++++++++++++++++++++++++++++++++ 2 files changed, 304 insertions(+), 5 deletions(-) create mode 100644 hashsync/multipeer_test.go diff --git a/hashsync/multipeer.go b/hashsync/multipeer.go index a5dbe72aeb..bfd9b1cf00 100644 --- a/hashsync/multipeer.go +++ b/hashsync/multipeer.go @@ -73,6 +73,12 @@ func WithSplitSyncGracePeriod(t time.Duration) MultiPeerReconcilerOpt { } } +func WithLogger(logger *zap.Logger) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.logger = logger + } +} + func withClock(clock clockwork.Clock) MultiPeerReconcilerOpt { return func(mpr *MultiPeerReconciler) { mpr.clock = clock @@ -119,13 +125,12 @@ type MultiPeerReconciler struct { } func NewMultiPeerReconciler( - logger *zap.Logger, syncBase syncBase, peers *peers.Peers, opts ...MultiPeerReconcilerOpt, ) *MultiPeerReconciler { mpr := &MultiPeerReconciler{ - logger: logger, + logger: zap.NewNop(), syncBase: syncBase, peers: peers, syncPeerCount: 20, @@ -149,7 +154,11 @@ func NewMultiPeerReconciler( func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p.Peer) (syncability, error) { var s syncability + s.syncable = nil + s.splitSyncable = nil + s.nearFullCount = 0 for _, p := range syncPeers { + mpr.logger.Debug("probe peer", zap.Stringer("peer", p)) pr, err := mpr.syncBase.probe(ctx, p) if err != nil { log.Warning("error probing the peer", zap.Any("peer", p), zap.Error(err)) @@ -160,17 +169,34 @@ func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p. 
} s.syncable = append(s.syncable, p) if pr.Count > mpr.minSplitSyncCount { + mpr.logger.Debug("splitSyncable peer", + zap.Stringer("peer", p), + zap.Int("count", pr.Count)) s.splitSyncable = append(s.splitSyncable, p) + } else { + mpr.logger.Debug("NOT splitSyncable peer", + zap.Stringer("peer", p), + zap.Int("count", pr.Count)) } + if (1-pr.Sim)*float64(mpr.syncBase.count()) < float64(mpr.maxFullDiff) { + mpr.logger.Debug("nearFull peer", + zap.Stringer("peer", p), + zap.Float64("sim", pr.Sim), + zap.Int("localCount", mpr.syncBase.count())) s.nearFullCount++ + } else { + mpr.logger.Debug("nearFull peer", + zap.Stringer("peer", p), + zap.Float64("sim", pr.Sim), + zap.Int("localCount", mpr.syncBase.count())) } } return s, nil } func (mpr *MultiPeerReconciler) needSplitSync(s syncability) bool { - if float64(s.nearFullCount) >= float64(mpr.syncBase.count())*mpr.minCompleteFraction { + if float64(s.nearFullCount) >= float64(len(s.syncable))*mpr.minCompleteFraction { // enough peers are close to this one according to minhash score, can do // full sync return false @@ -212,8 +238,10 @@ func (mpr *MultiPeerReconciler) syncOnce(ctx context.Context) error { ) for { syncPeers := mpr.peers.SelectBest(mpr.syncPeerCount) + mpr.logger.Debug("selected best peers for sync", zap.Int("numPeers", len(syncPeers))) if len(syncPeers) != 0 { // probePeers doesn't return transient errors, sync must stop if it failed + mpr.logger.Debug("probing peers", zap.Int("count", len(syncPeers))) s, err = mpr.probePeers(ctx, syncPeers) if err != nil { return err @@ -223,6 +251,7 @@ func (mpr *MultiPeerReconciler) syncOnce(ctx context.Context) error { } } + mpr.logger.Debug("no peers found, waiting", zap.Duration("duration", mpr.noPeersRecheckInterval)) select { case <-ctx.Done(): return ctx.Err() @@ -231,10 +260,19 @@ func (mpr *MultiPeerReconciler) syncOnce(ctx context.Context) error { } if mpr.needSplitSync(s) { - return mpr.runner.splitSync(ctx, s.splitSyncable) + mpr.logger.Debug("doing split 
sync", zap.Int("peerCount", len(s.splitSyncable))) + err = mpr.runner.splitSync(ctx, s.splitSyncable) } else { - return mpr.runner.fullSync(ctx, s.splitSyncable) + mpr.logger.Debug("doing full sync", zap.Int("peerCount", len(s.syncable))) + err = mpr.runner.fullSync(ctx, s.syncable) + } + + // handler errors are not fatal + if handlerErr := mpr.syncBase.wait(); handlerErr != nil { + mpr.logger.Error("error handling synced keys", zap.Error(handlerErr)) } + + return errors.Join(err) } func (mpr *MultiPeerReconciler) Run(ctx context.Context) error { diff --git a/hashsync/multipeer_test.go b/hashsync/multipeer_test.go new file mode 100644 index 0000000000..4c8073b642 --- /dev/null +++ b/hashsync/multipeer_test.go @@ -0,0 +1,261 @@ +package hashsync + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" + + "github.com/jonboulle/clockwork" + "github.com/spacemeshos/go-spacemesh/fetch/peers" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/stretchr/testify/require" +) + +// FIXME: BlockUntilContext is not included in FakeClock interface. +// This will be fixed in a post-0.4.0 clockwork release, but with a breaking change that +// makes FakeClock a struct instead of an interface. 
+// See: https://github.com/jonboulle/clockwork/pull/71 +type fakeClock interface { + clockwork.FakeClock + BlockUntilContext(ctx context.Context, n int) error +} + +type multiPeerSyncTester struct { + *testing.T + ctrl *gomock.Controller + syncBase *MocksyncBase + syncRunner *MocksyncRunner + peers *peers.Peers + clock fakeClock + reconciler *MultiPeerReconciler + selectedPeers []p2p.Peer + cancel context.CancelFunc + eg errgroup.Group +} + +func newMultiPeerSyncTester(t *testing.T) *multiPeerSyncTester { + ctrl := gomock.NewController(t) + mt := &multiPeerSyncTester{ + T: t, + ctrl: ctrl, + syncBase: NewMocksyncBase(ctrl), + syncRunner: NewMocksyncRunner(ctrl), + peers: peers.New(), + clock: clockwork.NewFakeClock().(fakeClock), + } + mt.reconciler = NewMultiPeerReconciler(mt.syncBase, mt.peers, + WithLogger(zaptest.NewLogger(t)), + WithSyncInterval(time.Minute), + WithSyncPeerCount(6), + WithMinSplitSyncPeers(2), + WithMinFullSyncCount(90), + WithMaxFullDiff(20), + WithMinCompleteFraction(0.9), + WithNoPeersRecheckInterval(10*time.Second), + withSyncRunner(mt.syncRunner), + withClock(mt.clock)) + return mt +} + +func (mt *multiPeerSyncTester) addPeers(n int) { + for i := 1; i <= n; i++ { + mt.peers.Add(p2p.Peer(fmt.Sprintf("peer%d", i))) + } +} + +func (mt *multiPeerSyncTester) start() context.Context { + var ctx context.Context + ctx, mt.cancel = context.WithTimeout(context.Background(), 10*time.Second) + mt.eg.Go(func() error { return mt.reconciler.Run(ctx) }) + mt.Cleanup(func() { + mt.cancel() + if err := mt.eg.Wait(); err != nil { + require.ErrorIs(mt, err, context.Canceled) + } + }) + return ctx +} + +func (mt *multiPeerSyncTester) expectProbe(times int, pr ProbeResult) { + mt.selectedPeers = nil + mt.syncBase.EXPECT().probe(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, p p2p.Peer) (ProbeResult, error) { + require.NotContains(mt, mt.selectedPeers, p, "peer probed twice") + require.True(mt, mt.peers.Contains(p)) + mt.selectedPeers = 
append(mt.selectedPeers, p) + return pr, nil + }).Times(times) +} + +func (mt *multiPeerSyncTester) expectFullSync(times, numFails int) { + mt.syncRunner.EXPECT().fullSync(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, peers []p2p.Peer) error { + require.ElementsMatch(mt, mt.selectedPeers, peers) + // delegate to the real fullsync + return mt.reconciler.fullSync(ctx, peers) + }) + mt.syncBase.EXPECT().derive(gomock.Any()).DoAndReturn(func(p p2p.Peer) syncer { + require.Contains(mt, mt.selectedPeers, p) + s := NewMocksyncer(mt.ctrl) + s.EXPECT().peer().Return(p).AnyTimes() + expSync := s.EXPECT().sync(gomock.Any(), gomock.Nil(), gomock.Nil()) + if numFails != 0 { + expSync.Return(errors.New("sync failed")) + numFails-- + } + return s + }).Times(times) +} + +// satisfy waits until all the expected mocked calls are made +func (mt *multiPeerSyncTester) satisfy() { + require.Eventually(mt, mt.ctrl.Satisfied, time.Second, time.Millisecond) +} + +func TestMultiPeerSync(t *testing.T) { + const numSyncs = 3 + + t.Run("split sync", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.clock.BlockUntilContext(ctx, 1) + // Advance by sync interval. No peers yet + mt.clock.Advance(time.Minute) + mt.clock.BlockUntilContext(ctx, 1) + mt.addPeers(10) + // Advance by peer wait time. 
After that, 6 peers will be selected + // randomly and probed + mt.syncBase.EXPECT().count().Return(50).AnyTimes() + for i := 0; i < numSyncs; i++ { + mt.expectProbe(6, ProbeResult{ + FP: "foo", + Count: 100, + Sim: 0.5, // too low for full sync + }) + mt.syncRunner.EXPECT().splitSync(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, peers []p2p.Peer) error { + require.ElementsMatch(t, mt.selectedPeers, peers) + return nil + }) + mt.syncBase.EXPECT().wait() + mt.clock.BlockUntilContext(ctx, 1) + if i > 0 { + mt.clock.Advance(time.Minute) + } else if i < numSyncs-1 { + mt.clock.Advance(10 * time.Second) + } + mt.satisfy() + } + mt.syncBase.EXPECT().wait() + }) + + t.Run("full sync", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.addPeers(10) + mt.syncBase.EXPECT().count().Return(100).AnyTimes() + for i := 0; i < numSyncs; i++ { + mt.expectProbe(6, ProbeResult{ + FP: "foo", + Count: 100, + Sim: 0.99, // high enough for full sync + }) + mt.expectFullSync(6, 0) + mt.syncBase.EXPECT().wait() + mt.clock.BlockUntilContext(ctx, 1) + mt.clock.Advance(time.Minute) + mt.satisfy() + } + mt.syncBase.EXPECT().wait() + }) + + t.Run("full sync due to low peer count", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.addPeers(1) + mt.syncBase.EXPECT().count().Return(50).AnyTimes() + for i := 0; i < numSyncs; i++ { + mt.expectProbe(1, ProbeResult{ + FP: "foo", + Count: 100, + Sim: 0.5, // too low for full sync, but will have it anyway + }) + mt.expectFullSync(1, 0) + mt.syncBase.EXPECT().wait() + mt.clock.BlockUntilContext(ctx, 1) + mt.clock.Advance(time.Minute) + mt.satisfy() + } + mt.syncBase.EXPECT().wait() + }) + + t.Run("probe failure", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.addPeers(10) + mt.syncBase.EXPECT().count().Return(100).AnyTimes() + mt.syncBase.EXPECT().probe(gomock.Any(), gomock.Any()). 
+ Return(ProbeResult{}, errors.New("probe failed")) + mt.expectProbe(5, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + // just 5 peers for which the probe worked will be checked + mt.expectFullSync(5, 0) + mt.syncBase.EXPECT().wait().Times(2) + mt.clock.BlockUntilContext(ctx, 1) + mt.clock.Advance(time.Minute) + }) + + t.Run("failed peers during full sync", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.addPeers(10) + mt.syncBase.EXPECT().count().Return(100).AnyTimes() + for i := 0; i < numSyncs; i++ { + mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + mt.expectFullSync(6, 3) + mt.syncBase.EXPECT().wait() + mt.clock.BlockUntilContext(ctx, 1) + mt.clock.Advance(time.Minute) + mt.satisfy() + } + mt.syncBase.EXPECT().wait() + }) + + t.Run("failed synced key handling during full sync", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.addPeers(10) + mt.syncBase.EXPECT().count().Return(100).AnyTimes() + for i := 0; i < numSyncs; i++ { + mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + mt.expectFullSync(6, 0) + mt.syncBase.EXPECT().wait().Return(errors.New("some handlers failed")) + mt.clock.BlockUntilContext(ctx, 1) + mt.clock.Advance(time.Minute) + mt.satisfy() + } + mt.syncBase.EXPECT().wait() + }) + + t.Run("cancellation during sync", func(t *testing.T) { + mt := newMultiPeerSyncTester(t) + ctx := mt.start() + mt.addPeers(10) + mt.syncBase.EXPECT().count().Return(100).AnyTimes() + mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + mt.syncRunner.EXPECT().fullSync(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, peers []p2p.Peer) error { + mt.cancel() + return ctx.Err() + }) + mt.syncBase.EXPECT().wait().Times(2) + mt.clock.BlockUntilContext(ctx, 1) + mt.clock.Advance(time.Minute) + require.ErrorIs(t, mt.eg.Wait(), context.Canceled) + }) +} From 1b731f334060e5d8f0db25c4950252268d6c5981 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov 
Date: Thu, 23 May 2024 16:32:26 +0400 Subject: [PATCH 23/76] p2p: server: use zap for logging --- fetch/fetch.go | 2 +- p2p/server/server.go | 70 ++++++++++++++++++++------------------- p2p/server/server_test.go | 5 ++- 3 files changed, 39 insertions(+), 38 deletions(-) diff --git a/fetch/fetch.go b/fetch/fetch.go index 95980920cf..8acd86b405 100644 --- a/fetch/fetch.go +++ b/fetch/fetch.go @@ -348,7 +348,7 @@ func (f *Fetch) registerServer( opts := []server.Opt{ server.WithTimeout(f.cfg.RequestTimeout), server.WithHardTimeout(f.cfg.RequestHardTimeout), - server.WithLog(f.logger), + server.WithLog(f.logger.Zap()), server.WithDecayingTag(f.cfg.DecayingTag), } if f.cfg.EnableServerMetrics { diff --git a/p2p/server/server.go b/p2p/server/server.go index 8d5d0002d3..079cecc78e 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -15,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p/core/protocol" "github.com/multiformats/go-varint" dto "github.com/prometheus/client_model/go" + "go.uber.org/zap" "golang.org/x/sync/errgroup" "golang.org/x/time/rate" @@ -58,7 +59,7 @@ func WithHardTimeout(timeout time.Duration) Opt { } // WithLog configures logger for the server. -func WithLog(log log.Log) Opt { +func WithLog(log *zap.Logger) Opt { return func(s *Server) { s.logger = log } @@ -146,7 +147,7 @@ type Response struct { // Server for the Handler. type Server struct { - logger log.Log + logger *zap.Logger protocol string handler StreamHandler timeout time.Duration @@ -166,7 +167,7 @@ type Server struct { // New server for the handler. 
func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { srv := &Server{ - logger: log.NewNop(), + logger: zap.NewNop(), protocol: proto, handler: handler, h: h, @@ -190,7 +191,7 @@ func New(h Host, proto string, handler StreamHandler, opts ...Opt) *Server { connmgr.DecayFixed(srv.decayingTagSpec.Dec), connmgr.BumpSumBounded(0, srv.decayingTagSpec.Cap)) if err != nil { - srv.logger.Error("error registering decaying tag", log.Err(err)) + srv.logger.Error("error registering decaying tag", zap.Error(err)) } else { srv.decayingTag = tag } @@ -267,21 +268,21 @@ func (s *Server) queueHandler(ctx context.Context, stream network.Stream) bool { rd := bufio.NewReader(dadj) size, err := varint.ReadUvarint(rd) if err != nil { - s.logger.With().Debug("initial read failed", - log.String("protocol", s.protocol), - log.Stringer("remotePeer", stream.Conn().RemotePeer()), - log.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), - log.Err(err), + s.logger.Debug("initial read failed", + zap.String("protocol", s.protocol), + zap.Stringer("remotePeer", stream.Conn().RemotePeer()), + zap.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), + zap.Error(err), ) return false } if size > uint64(s.requestLimit) { - s.logger.With().Warning("request limit overflow", - log.String("protocol", s.protocol), - log.Stringer("remotePeer", stream.Conn().RemotePeer()), - log.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), - log.Int("limit", s.requestLimit), - log.Uint64("request", size), + s.logger.Warn("request limit overflow", + zap.String("protocol", s.protocol), + zap.Stringer("remotePeer", stream.Conn().RemotePeer()), + zap.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), + zap.Int("limit", s.requestLimit), + zap.Uint64("request", size), ) stream.Conn().Close() return false @@ -289,29 +290,29 @@ func (s *Server) queueHandler(ctx context.Context, stream network.Stream) bool { buf := make([]byte, size) _, err = io.ReadFull(rd, buf) if 
err != nil { - s.logger.With().Debug("error reading request", - log.String("protocol", s.protocol), - log.Stringer("remotePeer", stream.Conn().RemotePeer()), - log.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), - log.Err(err), + s.logger.Debug("error reading request", + zap.String("protocol", s.protocol), + zap.Stringer("remotePeer", stream.Conn().RemotePeer()), + zap.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), + zap.Error(err), ) return false } start := time.Now() if err = s.handler(log.WithNewRequestID(ctx), buf, dadj); err != nil { - s.logger.With().Debug("handler reported error", - log.String("protocol", s.protocol), - log.Stringer("remotePeer", stream.Conn().RemotePeer()), - log.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), - log.Err(err), + s.logger.Debug("handler reported error", + zap.String("protocol", s.protocol), + zap.Stringer("remotePeer", stream.Conn().RemotePeer()), + zap.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), + zap.Error(err), ) return false } - s.logger.With().Debug("protocol handler execution time", - log.String("protocol", s.protocol), - log.Stringer("remotePeer", stream.Conn().RemotePeer()), - log.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), - log.Duration("duration", time.Since(start)), + s.logger.Debug("protocol handler execution time", + zap.String("protocol", s.protocol), + zap.Stringer("remotePeer", stream.Conn().RemotePeer()), + zap.Stringer("remoteMultiaddr", stream.Conn().RemoteMultiaddr()), + zap.Duration("duration", time.Since(start)), ) return true } @@ -356,10 +357,11 @@ func (s *Server) StreamRequest( stream, err := s.streamRequest(ctx, pid, req, extraProtocols...) 
if err == nil { err = callback(ctx, stream) - s.logger.WithContext(ctx).With().Debug("request execution time", - log.String("protocol", s.protocol), - log.Duration("duration", time.Since(start)), - log.Err(err), + s.logger.Debug("request execution time", + zap.String("protocol", s.protocol), + zap.Duration("duration", time.Since(start)), + zap.Error(err), + log.ZContext(ctx), ) } diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index 0290833d8d..f8c2c1cdd8 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -11,9 +11,8 @@ import ( "github.com/spacemeshos/go-scale/tester" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" - - "github.com/spacemeshos/go-spacemesh/log/logtest" ) func TestServer(t *testing.T) { @@ -33,7 +32,7 @@ func TestServer(t *testing.T) { } opts := []Opt{ WithTimeout(100 * time.Millisecond), - WithLog(logtest.New(t)), + WithLog(zaptest.NewLogger(t)), WithMetrics(), } client := New(mesh.Hosts()[0], proto, WrapHandler(handler), append(opts, WithRequestSizeLimit(2*limit))...) 
From fb22b9dcd7c6b32457c97e67716f01ebfbeebd1f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 23 May 2024 16:34:04 +0400 Subject: [PATCH 24/76] p2p: server: store peer ID in the context --- p2p/server/server.go | 22 +++++++++++++++++++++- p2p/server/server_test.go | 13 ++++++++++--- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/p2p/server/server.go b/p2p/server/server.go index 079cecc78e..29a6219493 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -102,12 +102,30 @@ func WithRequestsPerInterval(n int, interval time.Duration) Opt { } } +// WithDecayingTag specifies P2P decaying tag that is applied to the peer when a request +// is being served func WithDecayingTag(tag DecayingTagSpec) Opt { return func(s *Server) { s.decayingTagSpec = &tag } } +type peerIDKey = struct{} + +func withPeerID(ctx context.Context, peerID peer.ID) context.Context { + return context.WithValue(ctx, peerIDKey{}, peerID) +} + +// ContextPeerID retrieves the ID of the peer being served from the context and a boolean +// value indicating that the context contains peer ID. If there's no peer ID associated +// with the context, the function returns an empty peer ID and false. +func ContextPeerID(ctx context.Context) (peer.ID, bool) { + if v := ctx.Value(peerIDKey{}); v != nil { + return v.(peer.ID), true + } + return peer.ID(""), false +} + // Handler is a handler to be defined by the application. 
type Handler func(context.Context, []byte) ([]byte, error) @@ -243,9 +261,11 @@ func (s *Server) Run(ctx context.Context) error { eg.Wait() return nil } + peer := req.stream.Conn().RemotePeer() + ctx = withPeerID(ctx, peer) eg.Go(func() error { if s.decayingTag != nil { - s.decayingTag.Bump(req.stream.Conn().RemotePeer(), s.decayingTagSpec.Inc) + s.decayingTag.Bump(peer, s.decayingTagSpec.Inc) } ok := s.queueHandler(ctx, req.stream) if s.metrics != nil { diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index f8c2c1cdd8..96a9dc4156 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p/core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/spacemeshos/go-scale/tester" "github.com/stretchr/testify/assert" @@ -24,8 +25,10 @@ func TestServer(t *testing.T) { request := []byte("test request") testErr := errors.New("test error") - handler := func(_ context.Context, msg []byte) ([]byte, error) { - return msg, nil + handler := func(ctx context.Context, msg []byte) ([]byte, error) { + peerID, found := ContextPeerID(ctx) + require.True(t, found) + return append(msg, []byte(peerID)...), nil } errhandler := func(_ context.Context, _ []byte) ([]byte, error) { return nil, testErr @@ -40,6 +43,9 @@ func TestServer(t *testing.T) { srv2 := New(mesh.Hosts()[2], proto, WrapHandler(errhandler), append(opts, WithRequestSizeLimit(limit))...) srv3 := New(mesh.Hosts()[3], "otherproto", WrapHandler(errhandler), append(opts, WithRequestSizeLimit(limit))...) 
ctx, cancel := context.WithCancel(context.Background()) + noPeerID, found := ContextPeerID(ctx) + require.Equal(t, peer.ID(""), noPeerID) + require.False(t, found) var eg errgroup.Group eg.Go(func() error { return srv1.Run(ctx) @@ -67,7 +73,8 @@ func TestServer(t *testing.T) { n := srv1.NumAcceptedRequests() response, err := client.Request(ctx, mesh.Hosts()[1].ID(), request) require.NoError(t, err) - require.Equal(t, request, response) + expResponse := append(request, []byte(mesh.Hosts()[0].ID())...) + require.Equal(t, expResponse, response) require.NotEmpty(t, mesh.Hosts()[2].Network().ConnsToPeer(mesh.Hosts()[0].ID())) require.Equal(t, n+1, srv1.NumAcceptedRequests()) }) From 0ec78d7204ce96d3b4f79a244a243e8c9e12bbb5 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 23 May 2024 16:35:22 +0400 Subject: [PATCH 25/76] p2p: server: include 'read/write' in deadlineAdjuster error messages --- p2p/server/deadline_adjuster.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/p2p/server/deadline_adjuster.go b/p2p/server/deadline_adjuster.go index 1cdc3d4aa6..e839de0032 100644 --- a/p2p/server/deadline_adjuster.go +++ b/p2p/server/deadline_adjuster.go @@ -16,6 +16,7 @@ const ( ) type deadlineAdjusterError struct { + what string innerErr error elapsed time.Duration totalRead int @@ -29,7 +30,8 @@ func (err *deadlineAdjusterError) Unwrap() error { } func (err *deadlineAdjusterError) Error() string { - return fmt.Sprintf("%v elapsed, %d bytes read, %d bytes written, timeout %v, hard timeout %v: %v", + return fmt.Sprintf("%s: %v elapsed, %d bytes read, %d bytes written, timeout %v, hard timeout %v: %v", + err.what, err.elapsed, err.totalRead, err.totalWritten, @@ -73,6 +75,7 @@ func (dadj *deadlineAdjuster) augmentError(what string, err error) error { } return &deadlineAdjusterError{ + what: what, innerErr: err, elapsed: dadj.clock.Now().Sub(dadj.start), totalRead: dadj.totalRead, From 3de284bbe191c5e7c21091ea7fc6482927a45a7d Mon Sep 17 00:00:00 2001 
From: Ivan Shvedunov Date: Thu, 23 May 2024 16:36:26 +0400 Subject: [PATCH 26/76] p2p: server: close streams upon context cancellation --- p2p/server/deadline_adjuster.go | 13 ++++++++++--- p2p/server/server.go | 17 ++++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/p2p/server/deadline_adjuster.go b/p2p/server/deadline_adjuster.go index e839de0032..37d9906421 100644 --- a/p2p/server/deadline_adjuster.go +++ b/p2p/server/deadline_adjuster.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "sync" "time" "github.com/jonboulle/clockwork" @@ -52,6 +53,8 @@ type deadlineAdjuster struct { nextAdjustRead int nextAdjustWrite int hardDeadline time.Time + closeErr error + close sync.Once } var _ io.ReadWriteCloser = &deadlineAdjuster{} @@ -85,10 +88,14 @@ func (dadj *deadlineAdjuster) augmentError(what string, err error) error { } } +// Close closes the stream. This method is safe to call multiple times func (dadj *deadlineAdjuster) Close() error { - // FIXME: unsure if this is really needed (inherited from the older Server code) - _ = dadj.peerStream.SetDeadline(time.Time{}) - return dadj.peerStream.Close() + dadj.close.Do(func() { + // FIXME: unsure if this is really needed (inherited from the older Server code) + _ = dadj.peerStream.SetDeadline(time.Time{}) + dadj.closeErr = dadj.peerStream.Close() + }) + return dadj.closeErr } func (dadj *deadlineAdjuster) adjust() error { diff --git a/p2p/server/server.go b/p2p/server/server.go index 29a6219493..64e230f852 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -262,7 +262,15 @@ func (s *Server) Run(ctx context.Context) error { return nil } peer := req.stream.Conn().RemotePeer() - ctx = withPeerID(ctx, peer) + ctx, cancel := context.WithCancel(withPeerID(ctx, peer)) + defer cancel() + var eg errgroup.Group + eg.Go(func() error { + <-ctx.Done() + // deadlineAdjuster.Close() is safe to call multiple times + req.stream.Close() + return nil + }) eg.Go(func() error { if s.decayingTag != 
nil { s.decayingTag.Bump(peer, s.decayingTagSpec.Inc) @@ -376,6 +384,13 @@ func (s *Server) StreamRequest( defer cancel() stream, err := s.streamRequest(ctx, pid, req, extraProtocols...) if err == nil { + var eg errgroup.Group + eg.Go(func() error { + <-ctx.Done() + // deadlineAdjuster.Close() is safe to call multiple times + stream.Close() + return nil + }) err = callback(ctx, stream) s.logger.Debug("request execution time", zap.String("protocol", s.protocol), From 88a67e640d2b9a04b3a9603a2df35f5c61cdc1a2 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 23 May 2024 16:37:00 +0400 Subject: [PATCH 27/76] sync2: implement P2PHashSync Also, fix a race in syncTree and go version for proper loop var capture handling. --- go.mod | 2 +- hashsync/setsyncbase.go | 123 ----- {hashsync => sync2/hashsync}/handler.go | 115 +++-- {hashsync => sync2/hashsync}/handler_test.go | 51 +- {hashsync => sync2/hashsync}/interface.go | 31 +- {hashsync => sync2/hashsync}/mocks_test.go | 473 ++++++++++-------- {hashsync => sync2/hashsync}/monoid.go | 0 {hashsync => sync2/hashsync}/multipeer.go | 22 +- .../hashsync}/multipeer_test.go | 56 +-- {hashsync => sync2/hashsync}/rangesync.go | 22 +- .../hashsync}/rangesync_test.go | 0 sync2/hashsync/setsyncbase.go | 136 +++++ .../hashsync}/setsyncbase_test.go | 52 +- {hashsync => sync2/hashsync}/split_sync.go | 24 +- .../hashsync}/split_sync_test.go | 14 +- {hashsync => sync2/hashsync}/sync_queue.go | 0 .../hashsync}/sync_queue_test.go | 0 {hashsync => sync2/hashsync}/sync_tree.go | 4 +- .../hashsync}/sync_tree_store.go | 0 .../hashsync}/sync_tree_test.go | 56 +++ {hashsync => sync2/hashsync}/wire_types.go | 0 .../hashsync}/wire_types_scale.go | 0 {hashsync => sync2/hashsync}/xorsync.go | 0 {hashsync => sync2/hashsync}/xorsync_test.go | 6 +- sync2/p2p.go | 134 +++++ sync2/p2p_test.go | 102 ++++ 26 files changed, 912 insertions(+), 511 deletions(-) delete mode 100644 hashsync/setsyncbase.go rename {hashsync => sync2/hashsync}/handler.go (85%) 
rename {hashsync => sync2/hashsync}/handler_test.go (89%) rename {hashsync => sync2/hashsync}/interface.go (77%) rename {hashsync => sync2/hashsync}/mocks_test.go (63%) rename {hashsync => sync2/hashsync}/monoid.go (100%) rename {hashsync => sync2/hashsync}/multipeer.go (94%) rename {hashsync => sync2/hashsync}/multipeer_test.go (84%) rename {hashsync => sync2/hashsync}/rangesync.go (97%) rename {hashsync => sync2/hashsync}/rangesync_test.go (100%) create mode 100644 sync2/hashsync/setsyncbase.go rename {hashsync => sync2/hashsync}/setsyncbase_test.go (83%) rename {hashsync => sync2/hashsync}/split_sync.go (92%) rename {hashsync => sync2/hashsync}/split_sync_test.go (95%) rename {hashsync => sync2/hashsync}/sync_queue.go (100%) rename {hashsync => sync2/hashsync}/sync_queue_test.go (100%) rename {hashsync => sync2/hashsync}/sync_tree.go (99%) rename {hashsync => sync2/hashsync}/sync_tree_store.go (100%) rename {hashsync => sync2/hashsync}/sync_tree_test.go (92%) rename {hashsync => sync2/hashsync}/wire_types.go (100%) rename {hashsync => sync2/hashsync}/wire_types_scale.go (100%) rename {hashsync => sync2/hashsync}/xorsync.go (100%) rename {hashsync => sync2/hashsync}/xorsync_test.go (95%) create mode 100644 sync2/p2p.go create mode 100644 sync2/p2p_test.go diff --git a/go.mod b/go.mod index bd772e63d1..359827776d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/spacemeshos/go-spacemesh -go 1.21.8 +go 1.22.2 require ( cloud.google.com/go/storage v1.39.1 diff --git a/hashsync/setsyncbase.go b/hashsync/setsyncbase.go deleted file mode 100644 index 025be7a7ce..0000000000 --- a/hashsync/setsyncbase.go +++ /dev/null @@ -1,123 +0,0 @@ -package hashsync - -import ( - "context" - "errors" - "fmt" - - "github.com/spacemeshos/go-spacemesh/p2p" - "golang.org/x/sync/singleflight" - - "github.com/spacemeshos/go-spacemesh/common/types" -) - -type syncKeyHandler func(ctx context.Context, k Ordered) error - -type setSyncBase struct { - ps pairwiseSyncer - is 
ItemStore - handler syncKeyHandler - waiting []<-chan singleflight.Result - g singleflight.Group -} - -var _ syncBase = &setSyncBase{} - -func newSetSyncBase(ps pairwiseSyncer, is ItemStore, handler syncKeyHandler) *setSyncBase { - return &setSyncBase{ - ps: ps, - is: is, - handler: handler, - } -} - -// count implements syncBase. -func (ssb *setSyncBase) count() int { - it := ssb.is.Min() - if it == nil { - return 0 - } - x := it.Key() - return ssb.is.GetRangeInfo(nil, x, x, -1).Count -} - -// derive implements syncBase. -func (ssb *setSyncBase) derive(p p2p.Peer) syncer { - return &setSyncer{ - setSyncBase: ssb, - ItemStore: ssb.is.Copy(), - p: p, - } -} - -// probe implements syncBase. -func (ssb *setSyncBase) probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { - return ssb.ps.probe(ctx, p, ssb.is, nil, nil) -} - -func (ssb *setSyncBase) acceptKey(ctx context.Context, k Ordered) { - key := k.(fmt.Stringer).String() - if !ssb.is.Has(k) { - ssb.waiting = append(ssb.waiting, - ssb.g.DoChan(key, func() (any, error) { - err := ssb.handler(ctx, k) - if err == nil { - err = ssb.is.Add(ctx, k) - } - return key, err - })) - } -} - -func (ssb *setSyncBase) wait() error { - var errs []error - for _, w := range ssb.waiting { - r := <-w - ssb.g.Forget(r.Val.(string)) - errs = append(errs, r.Err) - } - ssb.waiting = nil - return errors.Join(errs...) -} - -type setSyncer struct { - *setSyncBase - ItemStore - p p2p.Peer -} - -var ( - _ syncer = &setSyncer{} - _ ItemStore = &setSyncer{} -) - -// peer implements syncer. -func (ss *setSyncer) peer() p2p.Peer { - return ss.p -} - -// sync implements syncer. -func (ss *setSyncer) sync(ctx context.Context, x, y *types.Hash32) error { - return ss.ps.syncStore(ctx, ss.p, ss, x, y) -} - -// Add implements ItemStore. 
-func (ss *setSyncer) Add(ctx context.Context, k Ordered) error { - ss.acceptKey(ctx, k) - return ss.ItemStore.Add(ctx, k) -} - -type realPairwiseSyncer struct { - r requester - opts []Option -} - -var _ pairwiseSyncer = &realPairwiseSyncer{} - -func (ps *realPairwiseSyncer) probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) { - return Probe(ctx, ps.r, peer, is, x, y, ps.opts...) -} - -func (ps *realPairwiseSyncer) syncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error { - return SyncStore(ctx, ps.r, peer, is, x, y, ps.opts...) -} diff --git a/hashsync/handler.go b/sync2/hashsync/handler.go similarity index 85% rename from hashsync/handler.go rename to sync2/hashsync/handler.go index 97086502ca..a7555f0685 100644 --- a/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -6,12 +6,11 @@ import ( "errors" "fmt" "io" - "sync" + "sync/atomic" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/p2p/server" ) type sendable interface { @@ -21,9 +20,8 @@ type sendable interface { // QQQQQ: rmme var ( - numRead int - numWritten int - smtx sync.Mutex + numRead atomic.Int64 + numWritten atomic.Int64 ) type rmmeCountingStream struct { @@ -32,19 +30,15 @@ type rmmeCountingStream struct { // Read implements io.ReadWriter. func (r *rmmeCountingStream) Read(p []byte) (n int, err error) { - smtx.Lock() - defer smtx.Unlock() n, err = r.ReadWriter.Read(p) - numRead += n + numRead.Add(int64(n)) return n, err } // Write implements io.ReadWriter. 
func (r *rmmeCountingStream) Write(p []byte) (n int, err error) { - smtx.Lock() - defer smtx.Unlock() n, err = r.ReadWriter.Write(p) - numWritten += n + numWritten.Add(int64(n)) return n, err } @@ -257,53 +251,22 @@ func (c *wireConduit) ShortenKey(k Ordered) Ordered { return MinhashSampleItemFromHash32(k.(types.Hash32)) } -func MakeServerHandler(is ItemStore, opts ...Option) server.StreamHandler { - return func(ctx context.Context, req []byte, stream io.ReadWriter) error { - var c wireConduit - rsr := NewRangeSetReconciler(is, opts...) - s := struct { - io.Reader - io.Writer - }{ - // prepend the received request to data being read - Reader: io.MultiReader(bytes.NewBuffer(req), stream), - Writer: stream, - } - return c.handleStream(ctx, s, rsr) - } +type PairwiseStoreSyncer struct { + r Requester + opts []RangeSetReconcilerOption } -func SyncStore(ctx context.Context, r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, opts ...Option) error { - var c wireConduit - rsr := NewRangeSetReconciler(is, opts...) 
- // c.rmmePrint = true - var ( - initReq []byte - err error - ) - if x == nil { - initReq, err = c.withInitialRequest(rsr.Initiate) - } else { - initReq, err = c.withInitialRequest(func(c Conduit) error { - return rsr.InitiateBounded(c, *x, *y) - }) - } - if err != nil { - return err - } - return r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { - s := &rmmeCountingStream{ReadWriter: stream} - return c.handleStream(ctx, s, rsr) - }) +var _ PairwiseSyncer = &PairwiseStoreSyncer{} + +func NewPairwiseStoreSyncer(r Requester, opts []RangeSetReconcilerOption) *PairwiseStoreSyncer { + return &PairwiseStoreSyncer{r: r, opts: opts} } -func Probe( +func (pss *PairwiseStoreSyncer) Probe( ctx context.Context, - r requester, peer p2p.Peer, is ItemStore, x, y *types.Hash32, - opts ...Option, ) (ProbeResult, error) { var ( err error @@ -312,7 +275,7 @@ func Probe( pr ProbeResult ) var c wireConduit - rsr := NewRangeSetReconciler(is, opts...) + rsr := NewRangeSetReconciler(is, pss.opts...) if x == nil { initReq, err = c.withInitialRequest(func(c Conduit) error { info, err = rsr.InitiateProbe(c) @@ -327,7 +290,7 @@ func Probe( if err != nil { return ProbeResult{}, err } - err = r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + err = pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { c.stream = stream var err error pr, err = rsr.HandleProbeResponse(&c, info) @@ -339,6 +302,54 @@ func Probe( return pr, nil } +func (pss *PairwiseStoreSyncer) SyncStore( + ctx context.Context, + peer p2p.Peer, + is ItemStore, + x, y *types.Hash32, +) error { + var c wireConduit + rsr := NewRangeSetReconciler(is, pss.opts...) 
+ // c.rmmePrint = true + var ( + initReq []byte + err error + ) + if x == nil { + initReq, err = c.withInitialRequest(rsr.Initiate) + } else { + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.InitiateBounded(c, *x, *y) + }) + } + if err != nil { + return err + } + return pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + s := &rmmeCountingStream{ReadWriter: stream} + return c.handleStream(ctx, s, rsr) + }) +} + +func (pss *PairwiseStoreSyncer) Serve( + ctx context.Context, + req []byte, + stream io.ReadWriter, + is ItemStore, +) error { + var c wireConduit + rsr := NewRangeSetReconciler(is, pss.opts...) + s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(req), stream), + Writer: stream, + } + return c.handleStream(ctx, s, rsr) +} + // TODO: request duration // TODO: validate counts // TODO: don't forget about Initiate!!! diff --git a/hashsync/handler_test.go b/sync2/hashsync/handler_test.go similarity index 89% rename from hashsync/handler_test.go rename to sync2/hashsync/handler_test.go index 0488b61559..ba96fbacbd 100644 --- a/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -11,10 +11,10 @@ import ( mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/server" ) @@ -33,9 +33,9 @@ type fakeRequester struct { bytesReceived uint32 } -var _ requester = &fakeRequester{} +var _ Requester = &fakeRequester{} -func newFakeRequester(id p2p.Peer, handler server.StreamHandler, peers ...requester) *fakeRequester { +func newFakeRequester(id p2p.Peer, handler server.StreamHandler, peers ...Requester) *fakeRequester { fr := 
&fakeRequester{ id: id, handler: handler, @@ -346,15 +346,18 @@ func TestWireConduit(t *testing.T) { require.NoError(t, err) } -type getRequesterFunc func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) +type getRequesterFunc func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) func withClientServer( store ItemStore, getRequester getRequesterFunc, - opts []Option, - toCall func(ctx context.Context, client requester, srvPeerID p2p.Peer), + opts []RangeSetReconcilerOption, + toCall func(ctx context.Context, client Requester, srvPeerID p2p.Peer), ) { - srvHandler := MakeServerHandler(store, opts...) + srvHandler := func(ctx context.Context, req []byte, stream io.ReadWriter) error { + pss := NewPairwiseStoreSyncer(nil, opts) + return pss.Serve(ctx, req, stream, store) + } srv, srvPeerID := getRequester("srv", srvHandler) var eg errgroup.Group ctx, cancel := context.WithCancel(context.Background()) @@ -371,7 +374,7 @@ func withClientServer( } func fakeRequesterGetter() getRequesterFunc { - return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { + return func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) { pid := p2p.Peer(name) return newFakeRequester(pid, handler, peers...), pid } @@ -383,9 +386,9 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { proto := "itest" opts := []server.Opt{ server.WithTimeout(10 * time.Second), - server.WithLog(logtest.New(t)), + server.WithLog(zaptest.NewLogger(t)), } - return func(name string, handler server.StreamHandler, peers ...requester) (requester, p2p.Peer) { + return func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) { if len(peers) == 0 { return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() } @@ -403,7 +406,7 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { } } -func testWireSync(t 
*testing.T, getRequester getRequesterFunc) requester { +func testWireSync(t *testing.T, getRequester getRequesterFunc) Requester { cfg := xorSyncTestConfig{ // large test: // maxSendRange: 1, @@ -419,21 +422,20 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) requester { minNumSpecificB: 4, maxNumSpecificB: 100, } - var client requester - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { + var client Requester + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { withClientServer( storeA, getRequester, opts, - func(ctx context.Context, client requester, srvPeerID p2p.Peer) { - err := SyncStore(ctx, client, srvPeerID, storeB, nil, nil, opts...) + func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { + pss := NewPairwiseStoreSyncer(client, opts) + err := pss.SyncStore(ctx, srvPeerID, storeB, nil, nil) require.NoError(t, err) if fr, ok := client.(*fakeRequester); ok { t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", numSpecific, fr.bytesSent, fr.bytesReceived) } - smtx.Lock() - t.Logf("bytes read: %d, bytes written: %d", numRead, numWritten) - smtx.Unlock() + t.Logf("bytes read: %d, bytes written: %d", numRead.Load(), numWritten.Load()) }) return true }) @@ -449,7 +451,7 @@ func TestWireSync(t *testing.T) { }) } -func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { +func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { cfg := xorSyncTestConfig{ maxSendRange: 1, numTestHashes: 10000, @@ -458,14 +460,15 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { minNumSpecificB: 130, maxNumSpecificB: 130, } - var client requester - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { + var client Requester + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { withClientServer( 
storeA, getRequester, opts, - func(ctx context.Context, client requester, srvPeerID p2p.Peer) { + func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { + pss := NewPairwiseStoreSyncer(client, opts) minA := storeA.Min().Key() infoA := storeA.GetRangeInfo(nil, minA, minA, -1) - prA, err := Probe(ctx, client, srvPeerID, storeB, nil, nil, opts...) + prA, err := pss.Probe(ctx, srvPeerID, storeB, nil, nil) require.NoError(t, err) require.Equal(t, infoA.Fingerprint, prA.FP) require.Equal(t, infoA.Count, prA.Count) @@ -476,7 +479,7 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) requester { x := partInfoA.Start.Key().(types.Hash32) y := partInfoA.End.Key().(types.Hash32) // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) - prA, err = Probe(ctx, client, srvPeerID, storeB, &x, &y, opts...) + prA, err = pss.Probe(ctx, srvPeerID, storeB, &x, &y) require.NoError(t, err) require.Equal(t, partInfoA.Fingerprint, prA.FP) require.Equal(t, partInfoA.Count, prA.Count) diff --git a/hashsync/interface.go b/sync2/hashsync/interface.go similarity index 77% rename from hashsync/interface.go rename to sync2/hashsync/interface.go index 6285010d47..6038750122 100644 --- a/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -2,6 +2,7 @@ package hashsync import ( "context" + "io" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" @@ -49,29 +50,31 @@ type ItemStore interface { Has(k Ordered) bool } -type requester interface { +type Requester interface { Run(context.Context) error StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error } -type syncBase interface { - count() int - derive(p p2p.Peer) syncer - probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) - wait() error +type SyncBase interface { + Count() int + Derive(p p2p.Peer) Syncer + Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) + Wait() error } -type syncer interface { - peer() p2p.Peer - sync(ctx 
context.Context, x, y *types.Hash32) error +type Syncer interface { + Peer() p2p.Peer + Sync(ctx context.Context, x, y *types.Hash32) error + Serve(ctx context.Context, req []byte, stream io.ReadWriter) error +} + +type PairwiseSyncer interface { + Probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) + SyncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error + Serve(ctx context.Context, req []byte, stream io.ReadWriter, is ItemStore) error } type syncRunner interface { splitSync(ctx context.Context, syncPeers []p2p.Peer) error fullSync(ctx context.Context, syncPeers []p2p.Peer) error } - -type pairwiseSyncer interface { - probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) - syncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error -} diff --git a/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go similarity index 63% rename from hashsync/mocks_test.go rename to sync2/hashsync/mocks_test.go index ec175b0b1d..a78fc16861 100644 --- a/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -11,6 +11,7 @@ package hashsync import ( context "context" + io "io" reflect "reflect" types "github.com/spacemeshos/go-spacemesh/common/types" @@ -405,31 +406,31 @@ func (c *MockItemStoreMinCall) DoAndReturn(f func() Iterator) *MockItemStoreMinC return c } -// Mockrequester is a mock of requester interface. -type Mockrequester struct { +// MockRequester is a mock of Requester interface. +type MockRequester struct { ctrl *gomock.Controller - recorder *MockrequesterMockRecorder + recorder *MockRequesterMockRecorder } -// MockrequesterMockRecorder is the mock recorder for Mockrequester. -type MockrequesterMockRecorder struct { - mock *Mockrequester +// MockRequesterMockRecorder is the mock recorder for MockRequester. +type MockRequesterMockRecorder struct { + mock *MockRequester } -// NewMockrequester creates a new mock instance. 
-func NewMockrequester(ctrl *gomock.Controller) *Mockrequester { - mock := &Mockrequester{ctrl: ctrl} - mock.recorder = &MockrequesterMockRecorder{mock} +// NewMockRequester creates a new mock instance. +func NewMockRequester(ctrl *gomock.Controller) *MockRequester { + mock := &MockRequester{ctrl: ctrl} + mock.recorder = &MockRequesterMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mockrequester) EXPECT() *MockrequesterMockRecorder { +func (m *MockRequester) EXPECT() *MockRequesterMockRecorder { return m.recorder } // Run mocks base method. -func (m *Mockrequester) Run(arg0 context.Context) error { +func (m *MockRequester) Run(arg0 context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Run", arg0) ret0, _ := ret[0].(error) @@ -437,37 +438,37 @@ func (m *Mockrequester) Run(arg0 context.Context) error { } // Run indicates an expected call of Run. -func (mr *MockrequesterMockRecorder) Run(arg0 any) *MockrequesterRunCall { +func (mr *MockRequesterMockRecorder) Run(arg0 any) *MockRequesterRunCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*Mockrequester)(nil).Run), arg0) - return &MockrequesterRunCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRequester)(nil).Run), arg0) + return &MockRequesterRunCall{Call: call} } -// MockrequesterRunCall wrap *gomock.Call -type MockrequesterRunCall struct { +// MockRequesterRunCall wrap *gomock.Call +type MockRequesterRunCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockrequesterRunCall) Return(arg0 error) *MockrequesterRunCall { +func (c *MockRequesterRunCall) Return(arg0 error) *MockRequesterRunCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockrequesterRunCall) Do(f func(context.Context) error) *MockrequesterRunCall { +func (c *MockRequesterRunCall) Do(f 
func(context.Context) error) *MockRequesterRunCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockrequesterRunCall) DoAndReturn(f func(context.Context) error) *MockrequesterRunCall { +func (c *MockRequesterRunCall) DoAndReturn(f func(context.Context) error) *MockRequesterRunCall { c.Call = c.Call.DoAndReturn(f) return c } // StreamRequest mocks base method. -func (m *Mockrequester) StreamRequest(arg0 context.Context, arg1 p2p.Peer, arg2 []byte, arg3 server.StreamRequestCallback, arg4 ...string) error { +func (m *MockRequester) StreamRequest(arg0 context.Context, arg1 p2p.Peer, arg2 []byte, arg3 server.StreamRequestCallback, arg4 ...string) error { m.ctrl.T.Helper() varargs := []any{arg0, arg1, arg2, arg3} for _, a := range arg4 { @@ -479,506 +480,582 @@ func (m *Mockrequester) StreamRequest(arg0 context.Context, arg1 p2p.Peer, arg2 } // StreamRequest indicates an expected call of StreamRequest. -func (mr *MockrequesterMockRecorder) StreamRequest(arg0, arg1, arg2, arg3 any, arg4 ...any) *MockrequesterStreamRequestCall { +func (mr *MockRequesterMockRecorder) StreamRequest(arg0, arg1, arg2, arg3 any, arg4 ...any) *MockRequesterStreamRequestCall { mr.mock.ctrl.T.Helper() varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamRequest", reflect.TypeOf((*Mockrequester)(nil).StreamRequest), varargs...) - return &MockrequesterStreamRequestCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamRequest", reflect.TypeOf((*MockRequester)(nil).StreamRequest), varargs...) 
+ return &MockRequesterStreamRequestCall{Call: call} } -// MockrequesterStreamRequestCall wrap *gomock.Call -type MockrequesterStreamRequestCall struct { +// MockRequesterStreamRequestCall wrap *gomock.Call +type MockRequesterStreamRequestCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockrequesterStreamRequestCall) Return(arg0 error) *MockrequesterStreamRequestCall { +func (c *MockRequesterStreamRequestCall) Return(arg0 error) *MockRequesterStreamRequestCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockrequesterStreamRequestCall) Do(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockrequesterStreamRequestCall { +func (c *MockRequesterStreamRequestCall) Do(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockRequesterStreamRequestCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockrequesterStreamRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockrequesterStreamRequestCall { +func (c *MockRequesterStreamRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockRequesterStreamRequestCall { c.Call = c.Call.DoAndReturn(f) return c } -// MocksyncBase is a mock of syncBase interface. -type MocksyncBase struct { +// MockSyncBase is a mock of SyncBase interface. +type MockSyncBase struct { ctrl *gomock.Controller - recorder *MocksyncBaseMockRecorder + recorder *MockSyncBaseMockRecorder } -// MocksyncBaseMockRecorder is the mock recorder for MocksyncBase. -type MocksyncBaseMockRecorder struct { - mock *MocksyncBase +// MockSyncBaseMockRecorder is the mock recorder for MockSyncBase. +type MockSyncBaseMockRecorder struct { + mock *MockSyncBase } -// NewMocksyncBase creates a new mock instance. 
-func NewMocksyncBase(ctrl *gomock.Controller) *MocksyncBase { - mock := &MocksyncBase{ctrl: ctrl} - mock.recorder = &MocksyncBaseMockRecorder{mock} +// NewMockSyncBase creates a new mock instance. +func NewMockSyncBase(ctrl *gomock.Controller) *MockSyncBase { + mock := &MockSyncBase{ctrl: ctrl} + mock.recorder = &MockSyncBaseMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MocksyncBase) EXPECT() *MocksyncBaseMockRecorder { +func (m *MockSyncBase) EXPECT() *MockSyncBaseMockRecorder { return m.recorder } -// count mocks base method. -func (m *MocksyncBase) count() int { +// Count mocks base method. +func (m *MockSyncBase) Count() int { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "count") + ret := m.ctrl.Call(m, "Count") ret0, _ := ret[0].(int) return ret0 } -// count indicates an expected call of count. -func (mr *MocksyncBaseMockRecorder) count() *MocksyncBasecountCall { +// Count indicates an expected call of Count. +func (mr *MockSyncBaseMockRecorder) Count() *MockSyncBaseCountCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "count", reflect.TypeOf((*MocksyncBase)(nil).count)) - return &MocksyncBasecountCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockSyncBase)(nil).Count)) + return &MockSyncBaseCountCall{Call: call} } -// MocksyncBasecountCall wrap *gomock.Call -type MocksyncBasecountCall struct { +// MockSyncBaseCountCall wrap *gomock.Call +type MockSyncBaseCountCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBasecountCall) Return(arg0 int) *MocksyncBasecountCall { +func (c *MockSyncBaseCountCall) Return(arg0 int) *MockSyncBaseCountCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBasecountCall) Do(f func() int) *MocksyncBasecountCall { +func (c *MockSyncBaseCountCall) Do(f func() int) *MockSyncBaseCountCall { c.Call = 
c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncBasecountCall) DoAndReturn(f func() int) *MocksyncBasecountCall { +func (c *MockSyncBaseCountCall) DoAndReturn(f func() int) *MockSyncBaseCountCall { c.Call = c.Call.DoAndReturn(f) return c } -// derive mocks base method. -func (m *MocksyncBase) derive(p p2p.Peer) syncer { +// Derive mocks base method. +func (m *MockSyncBase) Derive(p p2p.Peer) Syncer { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "derive", p) - ret0, _ := ret[0].(syncer) + ret := m.ctrl.Call(m, "Derive", p) + ret0, _ := ret[0].(Syncer) return ret0 } -// derive indicates an expected call of derive. -func (mr *MocksyncBaseMockRecorder) derive(p any) *MocksyncBasederiveCall { +// Derive indicates an expected call of Derive. +func (mr *MockSyncBaseMockRecorder) Derive(p any) *MockSyncBaseDeriveCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "derive", reflect.TypeOf((*MocksyncBase)(nil).derive), p) - return &MocksyncBasederiveCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Derive", reflect.TypeOf((*MockSyncBase)(nil).Derive), p) + return &MockSyncBaseDeriveCall{Call: call} } -// MocksyncBasederiveCall wrap *gomock.Call -type MocksyncBasederiveCall struct { +// MockSyncBaseDeriveCall wrap *gomock.Call +type MockSyncBaseDeriveCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBasederiveCall) Return(arg0 syncer) *MocksyncBasederiveCall { +func (c *MockSyncBaseDeriveCall) Return(arg0 Syncer) *MockSyncBaseDeriveCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBasederiveCall) Do(f func(p2p.Peer) syncer) *MocksyncBasederiveCall { +func (c *MockSyncBaseDeriveCall) Do(f func(p2p.Peer) Syncer) *MockSyncBaseDeriveCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncBasederiveCall) DoAndReturn(f func(p2p.Peer) syncer) 
*MocksyncBasederiveCall { +func (c *MockSyncBaseDeriveCall) DoAndReturn(f func(p2p.Peer) Syncer) *MockSyncBaseDeriveCall { c.Call = c.Call.DoAndReturn(f) return c } -// probe mocks base method. -func (m *MocksyncBase) probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { +// Probe mocks base method. +func (m *MockSyncBase) Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "probe", ctx, p) + ret := m.ctrl.Call(m, "Probe", ctx, p) ret0, _ := ret[0].(ProbeResult) ret1, _ := ret[1].(error) return ret0, ret1 } -// probe indicates an expected call of probe. -func (mr *MocksyncBaseMockRecorder) probe(ctx, p any) *MocksyncBaseprobeCall { +// Probe indicates an expected call of Probe. +func (mr *MockSyncBaseMockRecorder) Probe(ctx, p any) *MockSyncBaseProbeCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "probe", reflect.TypeOf((*MocksyncBase)(nil).probe), ctx, p) - return &MocksyncBaseprobeCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockSyncBase)(nil).Probe), ctx, p) + return &MockSyncBaseProbeCall{Call: call} } -// MocksyncBaseprobeCall wrap *gomock.Call -type MocksyncBaseprobeCall struct { +// MockSyncBaseProbeCall wrap *gomock.Call +type MockSyncBaseProbeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBaseprobeCall) Return(arg0 ProbeResult, arg1 error) *MocksyncBaseprobeCall { +func (c *MockSyncBaseProbeCall) Return(arg0 ProbeResult, arg1 error) *MockSyncBaseProbeCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBaseprobeCall) Do(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MocksyncBaseprobeCall { +func (c *MockSyncBaseProbeCall) Do(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MockSyncBaseProbeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c 
*MocksyncBaseprobeCall) DoAndReturn(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MocksyncBaseprobeCall { +func (c *MockSyncBaseProbeCall) DoAndReturn(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MockSyncBaseProbeCall { c.Call = c.Call.DoAndReturn(f) return c } -// wait mocks base method. -func (m *MocksyncBase) wait() error { +// Wait mocks base method. +func (m *MockSyncBase) Wait() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "wait") + ret := m.ctrl.Call(m, "Wait") ret0, _ := ret[0].(error) return ret0 } -// wait indicates an expected call of wait. -func (mr *MocksyncBaseMockRecorder) wait() *MocksyncBasewaitCall { +// Wait indicates an expected call of Wait. +func (mr *MockSyncBaseMockRecorder) Wait() *MockSyncBaseWaitCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "wait", reflect.TypeOf((*MocksyncBase)(nil).wait)) - return &MocksyncBasewaitCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockSyncBase)(nil).Wait)) + return &MockSyncBaseWaitCall{Call: call} } -// MocksyncBasewaitCall wrap *gomock.Call -type MocksyncBasewaitCall struct { +// MockSyncBaseWaitCall wrap *gomock.Call +type MockSyncBaseWaitCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncBasewaitCall) Return(arg0 error) *MocksyncBasewaitCall { +func (c *MockSyncBaseWaitCall) Return(arg0 error) *MockSyncBaseWaitCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncBasewaitCall) Do(f func() error) *MocksyncBasewaitCall { +func (c *MockSyncBaseWaitCall) Do(f func() error) *MockSyncBaseWaitCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncBasewaitCall) DoAndReturn(f func() error) *MocksyncBasewaitCall { +func (c *MockSyncBaseWaitCall) DoAndReturn(f func() error) *MockSyncBaseWaitCall { c.Call = c.Call.DoAndReturn(f) return c } -// Mocksyncer is a mock of 
syncer interface. -type Mocksyncer struct { +// MockSyncer is a mock of Syncer interface. +type MockSyncer struct { ctrl *gomock.Controller - recorder *MocksyncerMockRecorder + recorder *MockSyncerMockRecorder } -// MocksyncerMockRecorder is the mock recorder for Mocksyncer. -type MocksyncerMockRecorder struct { - mock *Mocksyncer +// MockSyncerMockRecorder is the mock recorder for MockSyncer. +type MockSyncerMockRecorder struct { + mock *MockSyncer } -// NewMocksyncer creates a new mock instance. -func NewMocksyncer(ctrl *gomock.Controller) *Mocksyncer { - mock := &Mocksyncer{ctrl: ctrl} - mock.recorder = &MocksyncerMockRecorder{mock} +// NewMockSyncer creates a new mock instance. +func NewMockSyncer(ctrl *gomock.Controller) *MockSyncer { + mock := &MockSyncer{ctrl: ctrl} + mock.recorder = &MockSyncerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mocksyncer) EXPECT() *MocksyncerMockRecorder { +func (m *MockSyncer) EXPECT() *MockSyncerMockRecorder { return m.recorder } -// peer mocks base method. -func (m *Mocksyncer) peer() p2p.Peer { +// Peer mocks base method. +func (m *MockSyncer) Peer() p2p.Peer { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "peer") + ret := m.ctrl.Call(m, "Peer") ret0, _ := ret[0].(p2p.Peer) return ret0 } -// peer indicates an expected call of peer. -func (mr *MocksyncerMockRecorder) peer() *MocksyncerpeerCall { +// Peer indicates an expected call of Peer. 
+func (mr *MockSyncerMockRecorder) Peer() *MockSyncerPeerCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "peer", reflect.TypeOf((*Mocksyncer)(nil).peer)) - return &MocksyncerpeerCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Peer", reflect.TypeOf((*MockSyncer)(nil).Peer)) + return &MockSyncerPeerCall{Call: call} } -// MocksyncerpeerCall wrap *gomock.Call -type MocksyncerpeerCall struct { +// MockSyncerPeerCall wrap *gomock.Call +type MockSyncerPeerCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncerpeerCall) Return(arg0 p2p.Peer) *MocksyncerpeerCall { +func (c *MockSyncerPeerCall) Return(arg0 p2p.Peer) *MockSyncerPeerCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncerpeerCall) Do(f func() p2p.Peer) *MocksyncerpeerCall { +func (c *MockSyncerPeerCall) Do(f func() p2p.Peer) *MockSyncerPeerCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncerpeerCall) DoAndReturn(f func() p2p.Peer) *MocksyncerpeerCall { +func (c *MockSyncerPeerCall) DoAndReturn(f func() p2p.Peer) *MockSyncerPeerCall { c.Call = c.Call.DoAndReturn(f) return c } -// sync mocks base method. -func (m *Mocksyncer) sync(ctx context.Context, x, y *types.Hash32) error { +// Serve mocks base method. +func (m *MockSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "sync", ctx, x, y) + ret := m.ctrl.Call(m, "Serve", ctx, req, stream) ret0, _ := ret[0].(error) return ret0 } -// sync indicates an expected call of sync. -func (mr *MocksyncerMockRecorder) sync(ctx, x, y any) *MocksyncersyncCall { +// Serve indicates an expected call of Serve. 
+func (mr *MockSyncerMockRecorder) Serve(ctx, req, stream any) *MockSyncerServeCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "sync", reflect.TypeOf((*Mocksyncer)(nil).sync), ctx, x, y) - return &MocksyncersyncCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockSyncer)(nil).Serve), ctx, req, stream) + return &MockSyncerServeCall{Call: call} } -// MocksyncersyncCall wrap *gomock.Call -type MocksyncersyncCall struct { +// MockSyncerServeCall wrap *gomock.Call +type MockSyncerServeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncersyncCall) Return(arg0 error) *MocksyncersyncCall { +func (c *MockSyncerServeCall) Return(arg0 error) *MockSyncerServeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncersyncCall) Do(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { +func (c *MockSyncerServeCall) Do(f func(context.Context, []byte, io.ReadWriter) error) *MockSyncerServeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncersyncCall) DoAndReturn(f func(context.Context, *types.Hash32, *types.Hash32) error) *MocksyncersyncCall { +func (c *MockSyncerServeCall) DoAndReturn(f func(context.Context, []byte, io.ReadWriter) error) *MockSyncerServeCall { c.Call = c.Call.DoAndReturn(f) return c } -// MocksyncRunner is a mock of syncRunner interface. -type MocksyncRunner struct { +// Sync mocks base method. +func (m *MockSyncer) Sync(ctx context.Context, x, y *types.Hash32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sync", ctx, x, y) + ret0, _ := ret[0].(error) + return ret0 +} + +// Sync indicates an expected call of Sync. 
+func (mr *MockSyncerMockRecorder) Sync(ctx, x, y any) *MockSyncerSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockSyncer)(nil).Sync), ctx, x, y) + return &MockSyncerSyncCall{Call: call} +} + +// MockSyncerSyncCall wrap *gomock.Call +type MockSyncerSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncerSyncCall) Return(arg0 error) *MockSyncerSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncerSyncCall) Do(f func(context.Context, *types.Hash32, *types.Hash32) error) *MockSyncerSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncerSyncCall) DoAndReturn(f func(context.Context, *types.Hash32, *types.Hash32) error) *MockSyncerSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockPairwiseSyncer is a mock of PairwiseSyncer interface. +type MockPairwiseSyncer struct { ctrl *gomock.Controller - recorder *MocksyncRunnerMockRecorder + recorder *MockPairwiseSyncerMockRecorder } -// MocksyncRunnerMockRecorder is the mock recorder for MocksyncRunner. -type MocksyncRunnerMockRecorder struct { - mock *MocksyncRunner +// MockPairwiseSyncerMockRecorder is the mock recorder for MockPairwiseSyncer. +type MockPairwiseSyncerMockRecorder struct { + mock *MockPairwiseSyncer } -// NewMocksyncRunner creates a new mock instance. -func NewMocksyncRunner(ctrl *gomock.Controller) *MocksyncRunner { - mock := &MocksyncRunner{ctrl: ctrl} - mock.recorder = &MocksyncRunnerMockRecorder{mock} +// NewMockPairwiseSyncer creates a new mock instance. +func NewMockPairwiseSyncer(ctrl *gomock.Controller) *MockPairwiseSyncer { + mock := &MockPairwiseSyncer{ctrl: ctrl} + mock.recorder = &MockPairwiseSyncerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. 
-func (m *MocksyncRunner) EXPECT() *MocksyncRunnerMockRecorder { +func (m *MockPairwiseSyncer) EXPECT() *MockPairwiseSyncerMockRecorder { return m.recorder } -// fullSync mocks base method. -func (m *MocksyncRunner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { +// Probe mocks base method. +func (m *MockPairwiseSyncer) Probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "fullSync", ctx, syncPeers) + ret := m.ctrl.Call(m, "Probe", ctx, peer, is, x, y) + ret0, _ := ret[0].(ProbeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Probe indicates an expected call of Probe. +func (mr *MockPairwiseSyncerMockRecorder) Probe(ctx, peer, is, x, y any) *MockPairwiseSyncerProbeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockPairwiseSyncer)(nil).Probe), ctx, peer, is, x, y) + return &MockPairwiseSyncerProbeCall{Call: call} +} + +// MockPairwiseSyncerProbeCall wrap *gomock.Call +type MockPairwiseSyncerProbeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPairwiseSyncerProbeCall) Return(arg0 ProbeResult, arg1 error) *MockPairwiseSyncerProbeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPairwiseSyncerProbeCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockPairwiseSyncerProbeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPairwiseSyncerProbeCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockPairwiseSyncerProbeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Serve mocks base method. 
+func (m *MockPairwiseSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter, is ItemStore) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Serve", ctx, req, stream, is) ret0, _ := ret[0].(error) return ret0 } -// fullSync indicates an expected call of fullSync. -func (mr *MocksyncRunnerMockRecorder) fullSync(ctx, syncPeers any) *MocksyncRunnerfullSyncCall { +// Serve indicates an expected call of Serve. +func (mr *MockPairwiseSyncerMockRecorder) Serve(ctx, req, stream, is any) *MockPairwiseSyncerServeCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "fullSync", reflect.TypeOf((*MocksyncRunner)(nil).fullSync), ctx, syncPeers) - return &MocksyncRunnerfullSyncCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockPairwiseSyncer)(nil).Serve), ctx, req, stream, is) + return &MockPairwiseSyncerServeCall{Call: call} } -// MocksyncRunnerfullSyncCall wrap *gomock.Call -type MocksyncRunnerfullSyncCall struct { +// MockPairwiseSyncerServeCall wrap *gomock.Call +type MockPairwiseSyncerServeCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncRunnerfullSyncCall) Return(arg0 error) *MocksyncRunnerfullSyncCall { +func (c *MockPairwiseSyncerServeCall) Return(arg0 error) *MockPairwiseSyncerServeCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncRunnerfullSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { +func (c *MockPairwiseSyncerServeCall) Do(f func(context.Context, []byte, io.ReadWriter, ItemStore) error) *MockPairwiseSyncerServeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncRunnerfullSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { +func (c *MockPairwiseSyncerServeCall) DoAndReturn(f func(context.Context, []byte, io.ReadWriter, ItemStore) error) 
*MockPairwiseSyncerServeCall { c.Call = c.Call.DoAndReturn(f) return c } -// splitSync mocks base method. -func (m *MocksyncRunner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { +// SyncStore mocks base method. +func (m *MockPairwiseSyncer) SyncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "splitSync", ctx, syncPeers) + ret := m.ctrl.Call(m, "SyncStore", ctx, peer, is, x, y) ret0, _ := ret[0].(error) return ret0 } -// splitSync indicates an expected call of splitSync. -func (mr *MocksyncRunnerMockRecorder) splitSync(ctx, syncPeers any) *MocksyncRunnersplitSyncCall { +// SyncStore indicates an expected call of SyncStore. +func (mr *MockPairwiseSyncerMockRecorder) SyncStore(ctx, peer, is, x, y any) *MockPairwiseSyncerSyncStoreCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "splitSync", reflect.TypeOf((*MocksyncRunner)(nil).splitSync), ctx, syncPeers) - return &MocksyncRunnersplitSyncCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncStore", reflect.TypeOf((*MockPairwiseSyncer)(nil).SyncStore), ctx, peer, is, x, y) + return &MockPairwiseSyncerSyncStoreCall{Call: call} } -// MocksyncRunnersplitSyncCall wrap *gomock.Call -type MocksyncRunnersplitSyncCall struct { +// MockPairwiseSyncerSyncStoreCall wrap *gomock.Call +type MockPairwiseSyncerSyncStoreCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MocksyncRunnersplitSyncCall) Return(arg0 error) *MocksyncRunnersplitSyncCall { +func (c *MockPairwiseSyncerSyncStoreCall) Return(arg0 error) *MockPairwiseSyncerSyncStoreCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MocksyncRunnersplitSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { +func (c *MockPairwiseSyncerSyncStoreCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) 
*MockPairwiseSyncerSyncStoreCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncRunnersplitSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { +func (c *MockPairwiseSyncerSyncStoreCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockPairwiseSyncerSyncStoreCall { c.Call = c.Call.DoAndReturn(f) return c } -// MockpairwiseSyncer is a mock of pairwiseSyncer interface. -type MockpairwiseSyncer struct { +// MocksyncRunner is a mock of syncRunner interface. +type MocksyncRunner struct { ctrl *gomock.Controller - recorder *MockpairwiseSyncerMockRecorder + recorder *MocksyncRunnerMockRecorder } -// MockpairwiseSyncerMockRecorder is the mock recorder for MockpairwiseSyncer. -type MockpairwiseSyncerMockRecorder struct { - mock *MockpairwiseSyncer +// MocksyncRunnerMockRecorder is the mock recorder for MocksyncRunner. +type MocksyncRunnerMockRecorder struct { + mock *MocksyncRunner } -// NewMockpairwiseSyncer creates a new mock instance. -func NewMockpairwiseSyncer(ctrl *gomock.Controller) *MockpairwiseSyncer { - mock := &MockpairwiseSyncer{ctrl: ctrl} - mock.recorder = &MockpairwiseSyncerMockRecorder{mock} +// NewMocksyncRunner creates a new mock instance. +func NewMocksyncRunner(ctrl *gomock.Controller) *MocksyncRunner { + mock := &MocksyncRunner{ctrl: ctrl} + mock.recorder = &MocksyncRunnerMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockpairwiseSyncer) EXPECT() *MockpairwiseSyncerMockRecorder { +func (m *MocksyncRunner) EXPECT() *MocksyncRunnerMockRecorder { return m.recorder } -// probe mocks base method. -func (m *MockpairwiseSyncer) probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) { +// fullSync mocks base method. 
+func (m *MocksyncRunner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "probe", ctx, peer, is, x, y) - ret0, _ := ret[0].(ProbeResult) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "fullSync", ctx, syncPeers) + ret0, _ := ret[0].(error) + return ret0 } -// probe indicates an expected call of probe. -func (mr *MockpairwiseSyncerMockRecorder) probe(ctx, peer, is, x, y any) *MockpairwiseSyncerprobeCall { +// fullSync indicates an expected call of fullSync. +func (mr *MocksyncRunnerMockRecorder) fullSync(ctx, syncPeers any) *MocksyncRunnerfullSyncCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "probe", reflect.TypeOf((*MockpairwiseSyncer)(nil).probe), ctx, peer, is, x, y) - return &MockpairwiseSyncerprobeCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "fullSync", reflect.TypeOf((*MocksyncRunner)(nil).fullSync), ctx, syncPeers) + return &MocksyncRunnerfullSyncCall{Call: call} } -// MockpairwiseSyncerprobeCall wrap *gomock.Call -type MockpairwiseSyncerprobeCall struct { +// MocksyncRunnerfullSyncCall wrap *gomock.Call +type MocksyncRunnerfullSyncCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockpairwiseSyncerprobeCall) Return(arg0 ProbeResult, arg1 error) *MockpairwiseSyncerprobeCall { - c.Call = c.Call.Return(arg0, arg1) +func (c *MocksyncRunnerfullSyncCall) Return(arg0 error) *MocksyncRunnerfullSyncCall { + c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockpairwiseSyncerprobeCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockpairwiseSyncerprobeCall { +func (c *MocksyncRunnerfullSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpairwiseSyncerprobeCall) DoAndReturn(f 
func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockpairwiseSyncerprobeCall { +func (c *MocksyncRunnerfullSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { c.Call = c.Call.DoAndReturn(f) return c } -// syncStore mocks base method. -func (m *MockpairwiseSyncer) syncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error { +// splitSync mocks base method. +func (m *MocksyncRunner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "syncStore", ctx, peer, is, x, y) + ret := m.ctrl.Call(m, "splitSync", ctx, syncPeers) ret0, _ := ret[0].(error) return ret0 } -// syncStore indicates an expected call of syncStore. -func (mr *MockpairwiseSyncerMockRecorder) syncStore(ctx, peer, is, x, y any) *MockpairwiseSyncersyncStoreCall { +// splitSync indicates an expected call of splitSync. +func (mr *MocksyncRunnerMockRecorder) splitSync(ctx, syncPeers any) *MocksyncRunnersplitSyncCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "syncStore", reflect.TypeOf((*MockpairwiseSyncer)(nil).syncStore), ctx, peer, is, x, y) - return &MockpairwiseSyncersyncStoreCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "splitSync", reflect.TypeOf((*MocksyncRunner)(nil).splitSync), ctx, syncPeers) + return &MocksyncRunnersplitSyncCall{Call: call} } -// MockpairwiseSyncersyncStoreCall wrap *gomock.Call -type MockpairwiseSyncersyncStoreCall struct { +// MocksyncRunnersplitSyncCall wrap *gomock.Call +type MocksyncRunnersplitSyncCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockpairwiseSyncersyncStoreCall) Return(arg0 error) *MockpairwiseSyncersyncStoreCall { +func (c *MocksyncRunnersplitSyncCall) Return(arg0 error) *MocksyncRunnersplitSyncCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c 
*MockpairwiseSyncersyncStoreCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockpairwiseSyncersyncStoreCall { +func (c *MocksyncRunnersplitSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockpairwiseSyncersyncStoreCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockpairwiseSyncersyncStoreCall { +func (c *MocksyncRunnersplitSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/hashsync/monoid.go b/sync2/hashsync/monoid.go similarity index 100% rename from hashsync/monoid.go rename to sync2/hashsync/monoid.go diff --git a/hashsync/multipeer.go b/sync2/hashsync/multipeer.go similarity index 94% rename from hashsync/multipeer.go rename to sync2/hashsync/multipeer.go index bfd9b1cf00..7ea20d630d 100644 --- a/hashsync/multipeer.go +++ b/sync2/hashsync/multipeer.go @@ -31,7 +31,7 @@ func WithSyncPeerCount(count int) MultiPeerReconcilerOpt { } } -func WithMinFullSyncCount(count int) MultiPeerReconcilerOpt { +func WithMinSplitSyncCount(count int) MultiPeerReconcilerOpt { return func(mpr *MultiPeerReconciler) { mpr.minSplitSyncCount = count } @@ -110,7 +110,7 @@ func (r *runner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { type MultiPeerReconciler struct { logger *zap.Logger - syncBase syncBase + syncBase SyncBase peers *peers.Peers syncPeerCount int minSplitSyncPeers int @@ -125,7 +125,7 @@ type MultiPeerReconciler struct { } func NewMultiPeerReconciler( - syncBase syncBase, + syncBase SyncBase, peers *peers.Peers, opts ...MultiPeerReconcilerOpt, ) *MultiPeerReconciler { @@ -159,7 +159,7 @@ func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p. 
s.nearFullCount = 0 for _, p := range syncPeers { mpr.logger.Debug("probe peer", zap.Stringer("peer", p)) - pr, err := mpr.syncBase.probe(ctx, p) + pr, err := mpr.syncBase.Probe(ctx, p) if err != nil { log.Warning("error probing the peer", zap.Any("peer", p), zap.Error(err)) if errors.Is(err, context.Canceled) { @@ -179,17 +179,17 @@ func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p. zap.Int("count", pr.Count)) } - if (1-pr.Sim)*float64(mpr.syncBase.count()) < float64(mpr.maxFullDiff) { + if (1-pr.Sim)*float64(mpr.syncBase.Count()) < float64(mpr.maxFullDiff) { mpr.logger.Debug("nearFull peer", zap.Stringer("peer", p), zap.Float64("sim", pr.Sim), - zap.Int("localCount", mpr.syncBase.count())) + zap.Int("localCount", mpr.syncBase.Count())) s.nearFullCount++ } else { mpr.logger.Debug("nearFull peer", zap.Stringer("peer", p), zap.Float64("sim", pr.Sim), - zap.Int("localCount", mpr.syncBase.count())) + zap.Int("localCount", mpr.syncBase.Count())) } } return s, nil @@ -213,9 +213,9 @@ func (mpr *MultiPeerReconciler) needSplitSync(s syncability) bool { func (mpr *MultiPeerReconciler) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { var eg errgroup.Group for _, p := range syncPeers { - syncer := mpr.syncBase.derive(p) + syncer := mpr.syncBase.Derive(p) eg.Go(func() error { - err := syncer.sync(ctx, nil, nil) + err := syncer.Sync(ctx, nil, nil) switch { case err == nil: case errors.Is(err, context.Canceled): @@ -268,7 +268,7 @@ func (mpr *MultiPeerReconciler) syncOnce(ctx context.Context) error { } // handler errors are not fatal - if handlerErr := mpr.syncBase.wait(); handlerErr != nil { + if handlerErr := mpr.syncBase.Wait(); handlerErr != nil { mpr.logger.Error("error handling synced keys", zap.Error(handlerErr)) } @@ -321,5 +321,5 @@ LOOP: } } cancel() - return errors.Join(err, mpr.syncBase.wait()) + return errors.Join(err, mpr.syncBase.Wait()) } diff --git a/hashsync/multipeer_test.go b/sync2/hashsync/multipeer_test.go 
similarity index 84% rename from hashsync/multipeer_test.go rename to sync2/hashsync/multipeer_test.go index 4c8073b642..ecc9ad6c1e 100644 --- a/hashsync/multipeer_test.go +++ b/sync2/hashsync/multipeer_test.go @@ -29,7 +29,7 @@ type fakeClock interface { type multiPeerSyncTester struct { *testing.T ctrl *gomock.Controller - syncBase *MocksyncBase + syncBase *MockSyncBase syncRunner *MocksyncRunner peers *peers.Peers clock fakeClock @@ -44,7 +44,7 @@ func newMultiPeerSyncTester(t *testing.T) *multiPeerSyncTester { mt := &multiPeerSyncTester{ T: t, ctrl: ctrl, - syncBase: NewMocksyncBase(ctrl), + syncBase: NewMockSyncBase(ctrl), syncRunner: NewMocksyncRunner(ctrl), peers: peers.New(), clock: clockwork.NewFakeClock().(fakeClock), @@ -54,7 +54,7 @@ func newMultiPeerSyncTester(t *testing.T) *multiPeerSyncTester { WithSyncInterval(time.Minute), WithSyncPeerCount(6), WithMinSplitSyncPeers(2), - WithMinFullSyncCount(90), + WithMinSplitSyncCount(90), WithMaxFullDiff(20), WithMinCompleteFraction(0.9), WithNoPeersRecheckInterval(10*time.Second), @@ -84,7 +84,7 @@ func (mt *multiPeerSyncTester) start() context.Context { func (mt *multiPeerSyncTester) expectProbe(times int, pr ProbeResult) { mt.selectedPeers = nil - mt.syncBase.EXPECT().probe(gomock.Any(), gomock.Any()).DoAndReturn( + mt.syncBase.EXPECT().Probe(gomock.Any(), gomock.Any()).DoAndReturn( func(_ context.Context, p p2p.Peer) (ProbeResult, error) { require.NotContains(mt, mt.selectedPeers, p, "peer probed twice") require.True(mt, mt.peers.Contains(p)) @@ -100,11 +100,11 @@ func (mt *multiPeerSyncTester) expectFullSync(times, numFails int) { // delegate to the real fullsync return mt.reconciler.fullSync(ctx, peers) }) - mt.syncBase.EXPECT().derive(gomock.Any()).DoAndReturn(func(p p2p.Peer) syncer { + mt.syncBase.EXPECT().Derive(gomock.Any()).DoAndReturn(func(p p2p.Peer) Syncer { require.Contains(mt, mt.selectedPeers, p) - s := NewMocksyncer(mt.ctrl) - s.EXPECT().peer().Return(p).AnyTimes() - expSync := 
s.EXPECT().sync(gomock.Any(), gomock.Nil(), gomock.Nil()) + s := NewMockSyncer(mt.ctrl) + s.EXPECT().Peer().Return(p).AnyTimes() + expSync := s.EXPECT().Sync(gomock.Any(), gomock.Nil(), gomock.Nil()) if numFails != 0 { expSync.Return(errors.New("sync failed")) numFails-- @@ -131,7 +131,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(10) // Advance by peer wait time. After that, 6 peers will be selected // randomly and probed - mt.syncBase.EXPECT().count().Return(50).AnyTimes() + mt.syncBase.EXPECT().Count().Return(50).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{ FP: "foo", @@ -143,7 +143,7 @@ func TestMultiPeerSync(t *testing.T) { require.ElementsMatch(t, mt.selectedPeers, peers) return nil }) - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() mt.clock.BlockUntilContext(ctx, 1) if i > 0 { mt.clock.Advance(time.Minute) @@ -152,14 +152,14 @@ func TestMultiPeerSync(t *testing.T) { } mt.satisfy() } - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() }) t.Run("full sync", func(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{ FP: "foo", @@ -167,19 +167,19 @@ func TestMultiPeerSync(t *testing.T) { Sim: 0.99, // high enough for full sync }) mt.expectFullSync(6, 0) - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() mt.clock.BlockUntilContext(ctx, 1) mt.clock.Advance(time.Minute) mt.satisfy() } - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() }) t.Run("full sync due to low peer count", func(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(1) - mt.syncBase.EXPECT().count().Return(50).AnyTimes() + mt.syncBase.EXPECT().Count().Return(50).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(1, ProbeResult{ FP: "foo", @@ -187,25 +187,25 @@ func TestMultiPeerSync(t 
*testing.T) { Sim: 0.5, // too low for full sync, but will have it anyway }) mt.expectFullSync(1, 0) - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() mt.clock.BlockUntilContext(ctx, 1) mt.clock.Advance(time.Minute) mt.satisfy() } - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() }) t.Run("probe failure", func(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().count().Return(100).AnyTimes() - mt.syncBase.EXPECT().probe(gomock.Any(), gomock.Any()). + mt.syncBase.EXPECT().Count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Probe(gomock.Any(), gomock.Any()). Return(ProbeResult{}, errors.New("probe failed")) mt.expectProbe(5, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) // just 5 peers for which the probe worked will be checked mt.expectFullSync(5, 0) - mt.syncBase.EXPECT().wait().Times(2) + mt.syncBase.EXPECT().Wait().Times(2) mt.clock.BlockUntilContext(ctx, 1) mt.clock.Advance(time.Minute) }) @@ -214,46 +214,46 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 3) - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() mt.clock.BlockUntilContext(ctx, 1) mt.clock.Advance(time.Minute) mt.satisfy() } - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() }) t.Run("failed synced key handling during full sync", func(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 0) - mt.syncBase.EXPECT().wait().Return(errors.New("some handlers failed")) + 
mt.syncBase.EXPECT().Wait().Return(errors.New("some handlers failed")) mt.clock.BlockUntilContext(ctx, 1) mt.clock.Advance(time.Minute) mt.satisfy() } - mt.syncBase.EXPECT().wait() + mt.syncBase.EXPECT().Wait() }) t.Run("cancellation during sync", func(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100).AnyTimes() mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.syncRunner.EXPECT().fullSync(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, peers []p2p.Peer) error { mt.cancel() return ctx.Err() }) - mt.syncBase.EXPECT().wait().Times(2) + mt.syncBase.EXPECT().Wait().Times(2) mt.clock.BlockUntilContext(ctx, 1) mt.clock.Advance(time.Minute) require.ErrorIs(t, mt.eg.Wait(), context.Canceled) diff --git a/hashsync/rangesync.go b/sync2/hashsync/rangesync.go similarity index 97% rename from hashsync/rangesync.go rename to sync2/hashsync/rangesync.go index 7dcf2d197b..40efc9cc3f 100644 --- a/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -10,9 +10,9 @@ import ( ) const ( - defaultMaxSendRange = 16 - defaultItemChunkSize = 16 - defaultSampleSize = 200 + DefaultMaxSendRange = 16 + DefaultItemChunkSize = 16 + DefaultSampleSize = 200 maxSampleSize = 1000 ) @@ -122,21 +122,21 @@ type Conduit interface { ShortenKey(k Ordered) Ordered } -type Option func(r *RangeSetReconciler) +type RangeSetReconcilerOption func(r *RangeSetReconciler) -func WithMaxSendRange(n int) Option { +func WithMaxSendRange(n int) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.maxSendRange = n } } -func WithItemChunkSize(n int) Option { +func WithItemChunkSize(n int) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.itemChunkSize = n } } -func WithSampleSize(s int) Option { +func WithSampleSize(s int) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.sampleSize = s } @@ -155,12 +155,12 
@@ type RangeSetReconciler struct { sampleSize int } -func NewRangeSetReconciler(is ItemStore, opts ...Option) *RangeSetReconciler { +func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *RangeSetReconciler { rsr := &RangeSetReconciler{ is: is, - maxSendRange: defaultMaxSendRange, - itemChunkSize: defaultItemChunkSize, - sampleSize: defaultSampleSize, + maxSendRange: DefaultMaxSendRange, + itemChunkSize: DefaultItemChunkSize, + sampleSize: DefaultSampleSize, } for _, opt := range opts { opt(rsr) diff --git a/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go similarity index 100% rename from hashsync/rangesync_test.go rename to sync2/hashsync/rangesync_test.go diff --git a/sync2/hashsync/setsyncbase.go b/sync2/hashsync/setsyncbase.go new file mode 100644 index 0000000000..869a5a57c4 --- /dev/null +++ b/sync2/hashsync/setsyncbase.go @@ -0,0 +1,136 @@ +package hashsync + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + + "github.com/spacemeshos/go-spacemesh/p2p" + "golang.org/x/sync/singleflight" + + "github.com/spacemeshos/go-spacemesh/common/types" +) + +type SyncKeyHandler func(ctx context.Context, k Ordered, peer p2p.Peer) error + +type SetSyncBase struct { + sync.Mutex + ps PairwiseSyncer + is ItemStore + handler SyncKeyHandler + waiting []<-chan singleflight.Result + g singleflight.Group +} + +var _ SyncBase = &SetSyncBase{} + +func NewSetSyncBase(ps PairwiseSyncer, is ItemStore, handler SyncKeyHandler) *SetSyncBase { + return &SetSyncBase{ + ps: ps, + is: is, + handler: handler, + } +} + +// Count implements syncBase. +func (ssb *SetSyncBase) Count() int { + ssb.Lock() + defer ssb.Unlock() + it := ssb.is.Min() + if it == nil { + return 0 + } + x := it.Key() + return ssb.is.GetRangeInfo(nil, x, x, -1).Count +} + +// Derive implements syncBase. 
+func (ssb *SetSyncBase) Derive(p p2p.Peer) Syncer { + ssb.Lock() + defer ssb.Unlock() + return &setSyncer{ + SetSyncBase: ssb, + ItemStore: ssb.is.Copy(), + p: p, + } +} + +// Probe implements syncBase. +func (ssb *SetSyncBase) Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { + // Use a snapshot of the store to avoid holding the mutex for a long time + ssb.Lock() + is := ssb.is.Copy() + ssb.Unlock() + + return ssb.ps.Probe(ctx, p, is, nil, nil) +} + +func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) { + ssb.Lock() + defer ssb.Unlock() + key := k.(fmt.Stringer).String() + if !ssb.is.Has(k) { + ssb.waiting = append(ssb.waiting, + ssb.g.DoChan(key, func() (any, error) { + err := ssb.handler(ctx, k, p) + if err == nil { + ssb.Lock() + defer ssb.Unlock() + err = ssb.is.Add(ctx, k) + } + return key, err + })) + } +} + +func (ssb *SetSyncBase) Wait() error { + // At this point, the derived syncers should be done syncing, and we only want to + // wait for the remaining handlers to complete. In case if some syncers happen to + // be still running at this point, let's not fail too badly. + // TODO: wait for any derived running syncers here, too + ssb.Lock() + waiting := ssb.waiting + ssb.waiting = nil + ssb.Unlock() + var errs []error + for _, w := range waiting { + r := <-w + ssb.g.Forget(r.Val.(string)) + errs = append(errs, r.Err) + } + return errors.Join(errs...) +} + +type setSyncer struct { + *SetSyncBase + ItemStore + p p2p.Peer +} + +var ( + _ Syncer = &setSyncer{} + _ ItemStore = &setSyncer{} +) + +// Peer implements syncer. +func (ss *setSyncer) Peer() p2p.Peer { + return ss.p +} + +// Sync implements syncer. 
+func (ss *setSyncer) Sync(ctx context.Context, x, y *types.Hash32) error { + return ss.ps.SyncStore(ctx, ss.p, ss, x, y) +} + +// Serve implements Syncer +func (ss *setSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter) error { + return ss.ps.Serve(ctx, req, stream, ss) +} + +// Add implements ItemStore. +func (ss *setSyncer) Add(ctx context.Context, k Ordered) error { + ss.acceptKey(ctx, k, ss.p) + return ss.ItemStore.Add(ctx, k) +} diff --git a/hashsync/setsyncbase_test.go b/sync2/hashsync/setsyncbase_test.go similarity index 83% rename from hashsync/setsyncbase_test.go rename to sync2/hashsync/setsyncbase_test.go index 618b222e95..333b1db695 100644 --- a/hashsync/setsyncbase_test.go +++ b/sync2/hashsync/setsyncbase_test.go @@ -16,9 +16,9 @@ import ( type setSyncBaseTester struct { *testing.T ctrl *gomock.Controller - ps *MockpairwiseSyncer + ps *MockPairwiseSyncer is *MockItemStore - ssb *setSyncBase + ssb *SetSyncBase waitMtx sync.Mutex waitChs map[Ordered]chan error doneCh chan Ordered @@ -29,7 +29,7 @@ func newSetSyncBaseTester(t *testing.T, is ItemStore) *setSyncBaseTester { st := &setSyncBaseTester{ T: t, ctrl: ctrl, - ps: NewMockpairwiseSyncer(ctrl), + ps: NewMockPairwiseSyncer(ctrl), waitChs: make(map[Ordered]chan error), doneCh: make(chan Ordered), } @@ -37,7 +37,7 @@ func newSetSyncBaseTester(t *testing.T, is ItemStore) *setSyncBaseTester { st.is = NewMockItemStore(ctrl) is = st.is } - st.ssb = newSetSyncBase(st.ps, is, func(ctx context.Context, k Ordered) error { + st.ssb = NewSetSyncBase(st.ps, is, func(ctx context.Context, k Ordered, p p2p.Peer) error { err := <-st.getWaitCh(k) st.doneCh <- k return err @@ -69,10 +69,10 @@ func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types. func (st *setSyncBaseTester) expectSyncStore( ctx context.Context, p p2p.Peer, - ss syncer, + ss Syncer, addedKeys ...types.Hash32, ) { - st.ps.EXPECT().syncStore(ctx, p, ss, nil, nil). 
+ st.ps.EXPECT().SyncStore(ctx, p, ss, nil, nil). DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { for _, k := range addedKeys { require.NoError(st, is.Add(ctx, k)) @@ -84,10 +84,10 @@ func (st *setSyncBaseTester) expectSyncStore( func (st *setSyncBaseTester) failToSyncStore( ctx context.Context, p p2p.Peer, - ss syncer, + ss Syncer, err error, ) { - st.ps.EXPECT().syncStore(ctx, p, ss, nil, nil). + st.ps.EXPECT().SyncStore(ctx, p, ss, nil, nil). DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { return err }) @@ -95,7 +95,7 @@ func (st *setSyncBaseTester) failToSyncStore( func (st *setSyncBaseTester) wait(count int) ([]types.Hash32, error) { var eg errgroup.Group - eg.Go(st.ssb.wait) + eg.Go(st.ssb.Wait) var handledKeys []types.Hash32 for k := range st.doneCh { handledKeys = append(handledKeys, k.(types.Hash32)) @@ -117,8 +117,8 @@ func TestSetSyncBase(t *testing.T) { Count: 42, Sim: 0.99, } - st.ps.EXPECT().probe(ctx, p2p.Peer("p1"), st.is, nil, nil).Return(expPr, nil) - pr, err := st.ssb.probe(ctx, p2p.Peer("p1")) + st.ps.EXPECT().Probe(ctx, p2p.Peer("p1"), st.is, nil, nil).Return(expPr, nil) + pr, err := st.ssb.Probe(ctx, p2p.Peer("p1")) require.NoError(t, err) require.Equal(t, expPr, pr) }) @@ -130,18 +130,18 @@ func TestSetSyncBase(t *testing.T) { addedKey := types.RandomHash() st.expectCopy(ctx, addedKey) - ss := st.ssb.derive(p2p.Peer("p1")) - require.Equal(t, p2p.Peer("p1"), ss.peer()) + ss := st.ssb.Derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.Peer()) x := types.RandomHash() y := types.RandomHash() - st.ps.EXPECT().syncStore(ctx, p2p.Peer("p1"), ss, &x, &y) - require.NoError(t, ss.sync(ctx, &x, &y)) + st.ps.EXPECT().SyncStore(ctx, p2p.Peer("p1"), ss, &x, &y) + require.NoError(t, ss.Sync(ctx, &x, &y)) st.is.EXPECT().Has(addedKey) st.is.EXPECT().Add(ctx, addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) - require.NoError(t, ss.sync(ctx, nil, 
nil)) + require.NoError(t, ss.Sync(ctx, nil, nil)) close(st.getWaitCh(addedKey)) handledKeys, err := st.wait(1) @@ -156,15 +156,15 @@ func TestSetSyncBase(t *testing.T) { addedKey := types.RandomHash() st.expectCopy(ctx, addedKey, addedKey, addedKey) - ss := st.ssb.derive(p2p.Peer("p1")) - require.Equal(t, p2p.Peer("p1"), ss.peer()) + ss := st.ssb.Derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.Peer()) // added just once st.is.EXPECT().Add(ctx, addedKey) for i := 0; i < 3; i++ { st.is.EXPECT().Has(addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) - require.NoError(t, ss.sync(ctx, nil, nil)) + require.NoError(t, ss.Sync(ctx, nil, nil)) } close(st.getWaitCh(addedKey)) @@ -181,15 +181,15 @@ func TestSetSyncBase(t *testing.T) { k1 := types.RandomHash() k2 := types.RandomHash() st.expectCopy(ctx, k1, k2) - ss := st.ssb.derive(p2p.Peer("p1")) - require.Equal(t, p2p.Peer("p1"), ss.peer()) + ss := st.ssb.Derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.Peer()) st.is.EXPECT().Has(k1) st.is.EXPECT().Has(k2) st.is.EXPECT().Add(ctx, k1) st.is.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) - require.NoError(t, ss.sync(ctx, nil, nil)) + require.NoError(t, ss.Sync(ctx, nil, nil)) close(st.getWaitCh(k1)) close(st.getWaitCh(k2)) @@ -206,15 +206,15 @@ func TestSetSyncBase(t *testing.T) { k1 := types.RandomHash() k2 := types.RandomHash() st.expectCopy(ctx, k1, k2) - ss := st.ssb.derive(p2p.Peer("p1")) - require.Equal(t, p2p.Peer("p1"), ss.peer()) + ss := st.ssb.Derive(p2p.Peer("p1")) + require.Equal(t, p2p.Peer("p1"), ss.Peer()) st.is.EXPECT().Has(k1) st.is.EXPECT().Has(k2) // k1 is not propagated to syncBase due to the handler failure st.is.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) - require.NoError(t, ss.sync(ctx, nil, nil)) + require.NoError(t, ss.Sync(ctx, nil, nil)) handlerErr := errors.New("fail") st.getWaitCh(k1) <- handlerErr close(st.getWaitCh(k2)) @@ -234,7 +234,7 @@ func 
TestSetSyncBase(t *testing.T) { is.Add(context.Background(), hs[0]) is.Add(context.Background(), hs[1]) st := newSetSyncBaseTester(t, is) - ss := st.ssb.derive(p2p.Peer("p1")) + ss := st.ssb.Derive(p2p.Peer("p1")) ss.(ItemStore).Add(context.Background(), hs[2]) ss.(ItemStore).Add(context.Background(), hs[3]) // syncer's cloned ItemStore has new key immediately diff --git a/hashsync/split_sync.go b/sync2/hashsync/split_sync.go similarity index 92% rename from hashsync/split_sync.go rename to sync2/hashsync/split_sync.go index 5d272a4296..b26187526d 100644 --- a/hashsync/split_sync.go +++ b/sync2/hashsync/split_sync.go @@ -17,13 +17,13 @@ import ( ) type syncResult struct { - s syncer + s Syncer err error } type splitSync struct { logger *zap.Logger - syncBase syncBase + syncBase SyncBase peers *peers.Peers syncPeers []p2p.Peer gracePeriod time.Duration @@ -36,13 +36,13 @@ type splitSync struct { numRunning int numRemaining int numPeers int - syncers []syncer + syncers []Syncer eg *errgroup.Group } func newSplitSync( logger *zap.Logger, - syncBase syncBase, + syncBase SyncBase, peers *peers.Peers, syncPeers []p2p.Peer, gracePeriod time.Duration, @@ -77,13 +77,13 @@ func (s *splitSync) nextPeer() p2p.Peer { } func (s *splitSync) startPeerSync(ctx context.Context, p p2p.Peer, sr *syncRange) { - syncer := s.syncBase.derive(p) + syncer := s.syncBase.Derive(p) sr.numSyncers++ s.numRunning++ doneCh := make(chan struct{}) s.eg.Go(func() error { defer close(doneCh) - err := syncer.sync(ctx, &sr.x, &sr.y) + err := syncer.Sync(ctx, &sr.x, &sr.y) select { case <-ctx.Done(): return ctx.Err() @@ -107,18 +107,18 @@ func (s *splitSync) startPeerSync(ctx context.Context, p p2p.Peer, sr *syncRange } func (s *splitSync) handleSyncResult(r syncResult) error { - sr, found := s.syncMap[r.s.peer()] + sr, found := s.syncMap[r.s.Peer()] if !found { panic("BUG: error in split sync syncMap handling") } s.numRunning-- - delete(s.syncMap, r.s.peer()) + delete(s.syncMap, r.s.Peer()) 
sr.numSyncers-- if r.err != nil { s.numPeers-- - s.failedPeers[r.s.peer()] = struct{}{} + s.failedPeers[r.s.Peer()] = struct{}{} s.logger.Debug("remove failed peer", - zap.Stringer("peer", r.s.peer()), + zap.Stringer("peer", r.s.Peer()), zap.Int("numPeers", s.numPeers), zap.Int("numRemaining", s.numRemaining), zap.Int("numRunning", s.numRunning), @@ -134,10 +134,10 @@ func (s *splitSync) handleSyncResult(r syncResult) error { } } else { sr.done = true - s.syncPeers = append(s.syncPeers, r.s.peer()) + s.syncPeers = append(s.syncPeers, r.s.Peer()) s.numRemaining-- s.logger.Debug("peer synced successfully", - zap.Stringer("peer", r.s.peer()), + zap.Stringer("peer", r.s.Peer()), zap.Int("numPeers", s.numPeers), zap.Int("numRemaining", s.numRemaining), zap.Int("numRunning", s.numRunning), diff --git a/hashsync/split_sync_test.go b/sync2/hashsync/split_sync_test.go similarity index 95% rename from hashsync/split_sync_test.go rename to sync2/hashsync/split_sync_test.go index 843e5c1dbc..e28bc45294 100644 --- a/hashsync/split_sync_test.go +++ b/sync2/hashsync/split_sync_test.go @@ -78,7 +78,7 @@ type splitSyncTester struct { fail map[hexRange]bool expPeerRanges map[hexRange]int peerRanges map[hexRange][]p2p.Peer - syncBase *MocksyncBase + syncBase *MockSyncBase peers *peers.Peers splitSync *splitSync } @@ -115,7 +115,7 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { tstRanges[3]: 0, }, peerRanges: make(map[hexRange][]p2p.Peer), - syncBase: NewMocksyncBase(ctrl), + syncBase: NewMockSyncBase(ctrl), peers: peers.New(), } for n := range tst.syncPeers { @@ -125,12 +125,12 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { index := index p := p tst.syncBase.EXPECT(). - derive(p). - DoAndReturn(func(peer p2p.Peer) syncer { - s := NewMocksyncer(ctrl) - s.EXPECT().peer().Return(p).AnyTimes() + Derive(p). + DoAndReturn(func(peer p2p.Peer) Syncer { + s := NewMockSyncer(ctrl) + s.EXPECT().Peer().Return(p).AnyTimes() s.EXPECT(). 
- sync(gomock.Any(), gomock.Any(), gomock.Any()). + Sync(gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(ctx context.Context, x, y *types.Hash32) error { tst.mtx.Lock() defer tst.mtx.Unlock() diff --git a/hashsync/sync_queue.go b/sync2/hashsync/sync_queue.go similarity index 100% rename from hashsync/sync_queue.go rename to sync2/hashsync/sync_queue.go diff --git a/hashsync/sync_queue_test.go b/sync2/hashsync/sync_queue_test.go similarity index 100% rename from hashsync/sync_queue_test.go rename to sync2/hashsync/sync_queue_test.go diff --git a/hashsync/sync_tree.go b/sync2/hashsync/sync_tree.go similarity index 99% rename from hashsync/sync_tree.go rename to sync2/hashsync/sync_tree.go index ff995f743b..690be069d2 100644 --- a/hashsync/sync_tree.go +++ b/sync2/hashsync/sync_tree.go @@ -512,7 +512,7 @@ func (st *syncTree) Set(k Ordered, v any) { func (st *syncTree) add(k Ordered, v any, set bool) { st.rootMtx.Lock() - st.rootMtx.Unlock() + defer st.rootMtx.Unlock() st.root = st.insert(st.root, k, v, true, set) if st.root.flags&flagBlack == 0 { st.root = st.ensureCloned(st.root) @@ -622,10 +622,12 @@ func (st *syncTree) findGTENode(ptr *syncTreePointer, x Ordered) bool { // or equal to x, we can find them on the right if ptr.node.right == nil { // sn.Max lied to us + // TODO: QQQQQ: this bug is being hit panic("BUG: SyncTreeNode: x > sn.Max but no right branch") } // Avoid endless recursion in case of a bug if x.Compare(ptr.node.right.max) > 0 { + // TODO: QQQQQ: this bug is being hit panic("BUG: SyncTreeNode: inconsistent Max on the right branch") } ptr.right() diff --git a/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go similarity index 100% rename from hashsync/sync_tree_store.go rename to sync2/hashsync/sync_tree_store.go diff --git a/hashsync/sync_tree_test.go b/sync2/hashsync/sync_tree_test.go similarity index 92% rename from hashsync/sync_tree_test.go rename to sync2/hashsync/sync_tree_test.go index aaa3b526cf..1eb60a06dc 100644 
--- a/hashsync/sync_tree_test.go +++ b/sync2/hashsync/sync_tree_test.go @@ -5,8 +5,10 @@ import ( "fmt" "math/rand" "slices" + "sync" "testing" + "github.com/spacemeshos/go-spacemesh/common/types" "github.com/stretchr/testify/require" ) @@ -510,3 +512,57 @@ func TestTreeValues(t *testing.T) { require.True(t, found) require.Equal(t, 456, v) } + +func TestParallelAddition(t *testing.T) { + for i := 0; i < 10; i++ { + const ( + nInitial = 10000 + nAdd = 1000 + nSets = 100 + ) + srcTree := NewSyncTree(Hash32To12Xor{}) + initialHashes := make([]types.Hash32, nInitial) + for n := range initialHashes { + h := types.RandomHash() + initialHashes[n] = h + srcTree.Add(h) + } + type set struct { + added []types.Hash32 + tree SyncTree + } + sets := make([]*set, nSets) + for n := range sets { + sets[n] = &set{} + } + sets[0].tree = srcTree + var wg sync.WaitGroup + for n, s := range sets { + wg.Add(1) + go func() { + defer wg.Done() + if n > 0 { + s.tree = srcTree.Copy() + } + s.added = make([]types.Hash32, nAdd) + for n := range s.added { + h := types.RandomHash() + s.added[n] = h + s.tree.Add(h) + } + }() + } + wg.Wait() + for _, s := range sets { + items := make(map[types.Hash32]struct{}, nInitial+nAdd) + for ptr := s.tree.Min(); ptr.Key() != nil; ptr.Next() { + items[ptr.Key().(types.Hash32)] = struct{}{} + } + require.GreaterOrEqual(t, len(items), nInitial+nAdd) + for _, k := range s.added { + _, found := items[k] // faster than require.Contains + require.True(t, found) + } + } + } +} diff --git a/hashsync/wire_types.go b/sync2/hashsync/wire_types.go similarity index 100% rename from hashsync/wire_types.go rename to sync2/hashsync/wire_types.go diff --git a/hashsync/wire_types_scale.go b/sync2/hashsync/wire_types_scale.go similarity index 100% rename from hashsync/wire_types_scale.go rename to sync2/hashsync/wire_types_scale.go diff --git a/hashsync/xorsync.go b/sync2/hashsync/xorsync.go similarity index 100% rename from hashsync/xorsync.go rename to 
sync2/hashsync/xorsync.go diff --git a/hashsync/xorsync_test.go b/sync2/hashsync/xorsync_test.go similarity index 95% rename from hashsync/xorsync_test.go rename to sync2/hashsync/xorsync_test.go index c8008f8f6d..aad8ccdbce 100644 --- a/hashsync/xorsync_test.go +++ b/sync2/hashsync/xorsync_test.go @@ -74,8 +74,8 @@ type xorSyncTestConfig struct { maxNumSpecificB int } -func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool) { - opts := []Option{ +func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool) { + opts := []RangeSetReconcilerOption{ WithMaxSendRange(cfg.maxSendRange), } numSpecificA := rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA @@ -125,7 +125,7 @@ func TestBigSyncHash32(t *testing.T) { minNumSpecificB: 4, maxNumSpecificB: 100, } - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []Option) bool { + verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { syncA := NewRangeSetReconciler(storeA, opts...) syncB := NewRangeSetReconciler(storeB, opts...) 
nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) diff --git a/sync2/p2p.go b/sync2/p2p.go new file mode 100644 index 0000000000..207d1dc0ba --- /dev/null +++ b/sync2/p2p.go @@ -0,0 +1,134 @@ +package hashsync + +import ( + "context" + "errors" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/fetch/peers" + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" +) + +type Config struct { + MaxSendRange int `mapstructure:"max-send-range"` + SampleSize int `mapstructure:"sample-size"` + Timeout time.Duration `mapstructure:"timeout"` + SyncPeerCount int `mapstructure:"sync-peer-count"` + MinSplitSyncCount int `mapstructure:"min-split-sync-count"` + MaxFullDiff int `mapstructure:"max-full-diff"` + SyncInterval time.Duration `mapstructure:"sync-interval"` + NoPeersRecheckInterval time.Duration `mapstructure:"no-peers-recheck-interval"` + MinSplitSyncPeers int `mapstructure:"min-split-sync-peers"` + MinCompleteFraction float64 `mapstructure:"min-complete-fraction"` + SplitSyncGracePeriod time.Duration `mapstructure:"split-sync-grace-period"` +} + +func DefaultConfig() Config { + return Config{ + MaxSendRange: hashsync.DefaultMaxSendRange, + SampleSize: hashsync.DefaultSampleSize, + Timeout: 10 * time.Second, + SyncPeerCount: 20, + MinSplitSyncPeers: 2, + MinSplitSyncCount: 1000, + MaxFullDiff: 10000, + SyncInterval: 5 * time.Minute, + MinCompleteFraction: 0.5, + SplitSyncGracePeriod: time.Minute, + NoPeersRecheckInterval: 30 * time.Second, + } +} + +type P2PHashSync struct { + logger *zap.Logger + h host.Host + is hashsync.ItemStore + syncBase hashsync.SyncBase + reconciler *hashsync.MultiPeerReconciler + srv *server.Server + cancel context.CancelFunc + eg errgroup.Group + start sync.Once + running atomic.Bool +} + +func NewP2PHashSync( + logger *zap.Logger, + h host.Host, + proto string, + 
peers *peers.Peers, + handler hashsync.SyncKeyHandler, + cfg Config, +) *P2PHashSync { + s := &P2PHashSync{ + logger: logger, + h: h, + is: hashsync.NewSyncTreeStore(hashsync.Hash32To12Xor{}), + } + s.srv = server.New(h, proto, s.handle, + server.WithTimeout(cfg.Timeout), + server.WithLog(logger)) + ps := hashsync.NewPairwiseStoreSyncer(s.srv, []hashsync.RangeSetReconcilerOption{ + hashsync.WithMaxSendRange(cfg.MaxSendRange), + hashsync.WithSampleSize(cfg.SampleSize), + }) + s.syncBase = hashsync.NewSetSyncBase(ps, s.is, handler) + s.reconciler = hashsync.NewMultiPeerReconciler( + s.syncBase, peers, + hashsync.WithLogger(logger), + hashsync.WithSyncPeerCount(cfg.SyncPeerCount), + hashsync.WithMinSplitSyncPeers(cfg.MinSplitSyncPeers), + hashsync.WithMinSplitSyncCount(cfg.MinSplitSyncCount), + hashsync.WithMaxFullDiff(cfg.MaxFullDiff), + hashsync.WithSyncInterval(cfg.SyncInterval), + hashsync.WithMinCompleteFraction(cfg.MinCompleteFraction), + hashsync.WithSplitSyncGracePeriod(time.Minute), + hashsync.WithNoPeersRecheckInterval(cfg.NoPeersRecheckInterval)) + return s +} + +func (s *P2PHashSync) handle(ctx context.Context, req []byte, stream io.ReadWriter) error { + if !s.running.Load() { + return errors.New("sync server not running") + } + peer, found := server.ContextPeerID(ctx) + if !found { + panic("BUG: no peer ID found in the handler") + } + // We derive a dedicated Syncer for the peer being served to pass all the received + // items through the handler before adding them to the main ItemStore + syncer := s.syncBase.Derive(peer) + return syncer.Serve(ctx, req, stream) +} + +func (s *P2PHashSync) ItemStore() hashsync.ItemStore { + return s.is +} + +func (s *P2PHashSync) Start() { + s.start.Do(func() { + var ctx context.Context + ctx, s.cancel = context.WithCancel(context.Background()) + s.eg.Go(func() error { return s.srv.Run(ctx) }) + s.eg.Go(func() error { return s.reconciler.Run(ctx) }) + s.running.Store(true) + }) +} + +func (s *P2PHashSync) Stop() { + 
s.running.Store(false) + if s.cancel != nil { + s.cancel() + } + if err := s.eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { + s.logger.Error("P2PHashSync terminated with an error", zap.Error(err)) + } +} diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go new file mode 100644 index 0000000000..de1d4449c7 --- /dev/null +++ b/sync2/p2p_test.go @@ -0,0 +1,102 @@ +package hashsync + +import ( + "context" + "sync" + "testing" + "time" + + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/fetch/peers" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" +) + +func TestP2P(t *testing.T) { + const ( + numNodes = 4 + numHashes = 100 + ) + logger := zaptest.NewLogger(t) + mesh, err := mocknet.FullMeshConnected(numNodes) + require.NoError(t, err) + type addedKey struct { + fromPeer, toPeer p2p.Peer + key hashsync.Ordered + } + var mtx sync.Mutex + synced := make(map[addedKey]struct{}) + hs := make([]*P2PHashSync, numNodes) + initialSet := make([]types.Hash32, numHashes) + for n := range initialSet { + initialSet[n] = types.RandomHash() + } + for n := range hs { + ps := peers.New() + for m := 0; m < numNodes; m++ { + if m != n { + ps.Add(mesh.Hosts()[m].ID()) + } + } + cfg := DefaultConfig() + cfg.SyncInterval = 100 * time.Millisecond + host := mesh.Hosts()[n] + handler := func(ctx context.Context, k hashsync.Ordered, peer p2p.Peer) error { + mtx.Lock() + defer mtx.Unlock() + ak := addedKey{ + fromPeer: peer, + toPeer: host.ID(), + key: k, + } + synced[ak] = struct{}{} + return nil + } + hs[n] = NewP2PHashSync(logger, host, "sync2test", ps, handler, cfg) + if n == 0 { + is := hs[n].ItemStore() + for _, h := range initialSet { + is.Add(context.Background(), h) + } + } + hs[n].Start() + } + + require.Eventually(t, func() bool { + for _, hsync := range hs { + // 
use a snapshot to avoid races + is := hsync.ItemStore().Copy() + it := is.Min() + if it == nil { + return false + } + if is.GetRangeInfo(nil, it.Key(), it.Key(), -1).Count < numHashes { + return false + } + } + return true + }, 30*time.Second, 300*time.Millisecond) + + for _, hsync := range hs { + hsync.Stop() + min := hsync.ItemStore().Min() + it := hsync.ItemStore().Min() + require.NotNil(t, it) + var actualItems []types.Hash32 + for { + k := it.Key().(types.Hash32) + actualItems = append(actualItems, k) + it.Next() + if it.Equal(min) { + break + } + } + require.ElementsMatch(t, initialSet, actualItems) + } +} + +// TODO: make sure all the keys have passed through the handler before being added From ef020242e100cbe1629afbd2abd1e7e9b460246e Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 23 May 2024 17:10:11 +0400 Subject: [PATCH 28/76] sync2: fixup --- common/types/hashes.go | 17 ++++++++++++----- sync2/p2p.go | 2 +- sync2/p2p_test.go | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/common/types/hashes.go b/common/types/hashes.go index 6ebbfe5e2a..6638489083 100644 --- a/common/types/hashes.go +++ b/common/types/hashes.go @@ -16,6 +16,7 @@ import ( const ( Hash32Length = 32 Hash20Length = 20 + Hash12Length = 12 ) var ( @@ -29,11 +30,8 @@ type Hash32 [Hash32Length]byte // Hash20 represents the 20-byte blake3 hash of arbitrary data. type Hash20 [Hash20Length]byte -// String implements the stringer interface and is used also by the logger when -// doing full logging into a file. -func (h Hash12) String() string { - return util.Encode(h[:5]) -} +// Hash12 represents the 12-byte hash used for sync +type Hash12 [Hash12Length]byte // Bytes gets the byte representation of the underlying hash. func (h Hash20) Bytes() []byte { return h[:] } @@ -93,6 +91,15 @@ func (h Hash20) ToHash32() (h32 Hash32) { return } +// String implements the stringer interface and is used also by the logger when +// doing full logging into a file. 
+func (h Hash12) String() string { + return util.Encode(h[:5]) +} + +// Field returns a log field. Implements the LoggableField interface. +func (h Hash12) Field() log.Field { return log.String("hash", hex.EncodeToString(h[:])) } + // CalcProposalsHash32 returns the 32-byte blake3 sum of the IDs, sorted in lexicographic order. The pre-image is // prefixed with additionalBytes. func CalcProposalsHash32(view []ProposalID, additionalBytes []byte) Hash32 { diff --git a/sync2/p2p.go b/sync2/p2p.go index 207d1dc0ba..fca6a84665 100644 --- a/sync2/p2p.go +++ b/sync2/p2p.go @@ -1,4 +1,4 @@ -package hashsync +package sync2 import ( "context" diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index de1d4449c7..35665a2fd2 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -1,4 +1,4 @@ -package hashsync +package sync2 import ( "context" From ad768820c972b6c9cf41fd22a33578f950979f09 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 18 Jun 2024 08:56:29 +0400 Subject: [PATCH 29/76] wip --- sync2/dbsync/dbsync.go | 339 ++++++++++++++++++++++++++++++++++++ sync2/dbsync/dbsync_test.go | 317 +++++++++++++++++++++++++++++++++ 2 files changed, 656 insertions(+) create mode 100644 sync2/dbsync/dbsync.go create mode 100644 sync2/dbsync/dbsync_test.go diff --git a/sync2/dbsync/dbsync.go b/sync2/dbsync/dbsync.go new file mode 100644 index 0000000000..1d9a4db167 --- /dev/null +++ b/sync2/dbsync/dbsync.go @@ -0,0 +1,339 @@ +package dbsync + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "math/bits" + "os" + + "golang.org/x/exp/slices" +) + +const ( + fingerprintBytes = 12 + cachedBits = 24 + cachedSize = 1 << cachedBits + cacheMask = cachedSize - 1 + maxIDBytes = 32 + bit63 = 1 << 63 +) + +type fingerprint [fingerprintBytes]byte + +func (fp fingerprint) String() string { + return hex.EncodeToString(fp[:]) +} + +func (fp *fingerprint) update(h []byte) { + for n := range *fp { + (*fp)[n] ^= h[n] + } +} + +func hexToFingerprint(s string) fingerprint { + 
b, err := hex.DecodeString(s) + if err != nil { + panic("bad hex fingerprint: " + err.Error()) + } + var fp fingerprint + if len(b) != len(fp) { + panic("bad hex fingerprint") + } + copy(fp[:], b) + return fp +} + +// const ( +// nodeFlagLeaf = 1 << 32 +// nodeFlagChanged = 1 << 31 +// ) + +// NOTE: all leafs are on the last level + +type node struct { + // 16-byte structure with alignment + // The cache is 512 MiB per 1<<24 (16777216) IDs + fp fingerprint + count uint32 +} + +type cacheIndex uint32 + +const ( + prefixLenBits = 6 + prefixLenMask = 1<> prefixLenBits) +} + +func (p prefix) left() prefix { + l := uint64(p) & prefixLenMask + if l == maxPrefixLen { + panic("BUG: max prefix len reached") + } + return prefix((uint64(p)&prefixBitMask)<<1 + l + 1) +} + +func (p prefix) right() prefix { + return p.left() + (1 << prefixLenBits) +} + +func (p prefix) cacheIndex() (cacheIndex, bool) { + if l := p.len(); l <= cachedBits { + // Notation: prefix(cacheIndex) + // + // empty(0) + // / \ + // / \ + // / \ + // 0(1) 1(2) + // / \ / \ + // / \ / \ + // 00(3) 01(4) 10(5) 11(6) + + // indexing starts at 1 + // left: n = n*2 + // right: n = n*2+1 + // but in the end we substract 1 to make it 0-based again + + return cacheIndex(p.bits() | (1 << l) - 1), true + } + return 0, false +} + +func (p prefix) String() string { + if p.len() == 0 { + return "<0>" + } + b := fmt.Sprintf("%064b", p.bits()) + return fmt.Sprintf("<%d:%s>", p.len(), b[64-p.len():]) +} + +func load64(h []byte) uint64 { + return binary.BigEndian.Uint64(h[:8]) +} + +func hashPrefix(h []byte, nbits int) prefix { + if nbits < 0 || nbits > maxPrefixLen { + panic("BUG: bad prefix length") + } + if nbits == 0 { + return 0 + } + v := load64(h) + return prefix((v>>(64-nbits-prefixLenBits))&prefixBitMask + uint64(nbits)) +} + +func preFirst0(h []byte) prefix { + l := min(maxPrefixLen, bits.LeadingZeros64(^load64(h))) + return hashPrefix(h, l) +} + +func preFirst1(h []byte) prefix { + l := min(maxPrefixLen, 
bits.LeadingZeros64(load64(h))) + return hashPrefix(h, l) +} + +func commonPrefix(a, b []byte) prefix { + v1 := load64(a) + v2 := load64(b) + l := uint64(min(maxPrefixLen, bits.LeadingZeros64(v1^v2))) + return prefix((v1>>(64-l))< cachedBits: + panic("BUG: prefix too long") + case p.len() == cachedBits: + return + case v&bit63 == 0: + p = p.left() + default: + p = p.right() + } + v <<= 1 + } +} + +func (ft *fpTree) aggregateLeft(v uint64, p prefix, r *aggResult) { + bit := v & (1 << (63 - p.len())) + switch { + case p.len() >= cachedBits: + r.tails = append(r.tails, p) + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, p) + case bit == 0: + idx, gotIdx := p.right().cacheIndex() + if !gotIdx { + panic("BUG: no idx") + } + r.update(&ft.nodes[idx]) + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s\n", v, p, + ft.nodes[idx].count, ft.nodes[idx].fp) + ft.aggregateLeft(v, p.left(), r) + default: + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) + ft.aggregateLeft(v, p.right(), r) + } +} + +func (ft *fpTree) aggregateRight(v uint64, p prefix, r *aggResult) { + bit := v & (1 << (63 - p.len())) + switch { + case p.len() >= cachedBits: + r.tails = append(r.tails, p) + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) + case bit == 0: + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> go left\n", v, p) + ft.aggregateRight(v, p.left(), r) + default: + idx, gotIdx := p.left().cacheIndex() + if !gotIdx { + panic("BUG: no idx") + } + r.update(&ft.nodes[idx]) + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> add count %d fp %s + go right\n", v, p, + ft.nodes[idx].count, ft.nodes[idx].fp) + ft.aggregateRight(v, p.right(), r) + } +} + +func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { + var r aggResult + r.itype = bytes.Compare(x, y) + switch { + case r.itype == 0: + // the whole set + r.update(&ft.nodes[0]) + case r.itype < 0: + // 
"proper" interval: [x; lca); (lca; y) + p := commonPrefix(x, y) + ft.aggregateLeft(load64(x), p.left(), &r) + ft.aggregateRight(load64(y), p.right(), &r) + default: + // inverse interval: [min; y); [x; max] + ft.aggregateRight(load64(y), preFirst1(y), &r) + ft.aggregateLeft(load64(x), preFirst0(x), &r) + } + return r +} + +func (ft *fpTree) dumpNode(w io.Writer, p prefix, indent, dir string) { + idx, gotIdx := p.cacheIndex() + if !gotIdx { + return + } + c := ft.nodes[idx].count + if c == 0 { + return + } + fmt.Fprintf(w, "%s%s%s %d\n", indent, dir, ft.nodes[idx].fp, c) + if c > 1 { + indent += " " + ft.dumpNode(w, p.left(), indent, "l: ") + ft.dumpNode(w, p.right(), indent, "r: ") + } +} + +func (ft *fpTree) dump(w io.Writer) { + if ft.nodes[0].count == 0 { + fmt.Fprintln(w, "empty tree") + } else { + ft.dumpNode(w, 0, "", "") + } +} + +type inMemFPTree struct { + tree fpTree + ids [cachedSize][][]byte +} + +func (mft *inMemFPTree) addHash(h []byte) { + mft.tree.addHash(h) + idx := load64(h) >> (64 - cachedBits) + s := mft.ids[idx] + n := slices.IndexFunc(s, func(cur []byte) bool { + return bytes.Compare(cur, h) > 0 + }) + if n < 0 { + mft.ids[idx] = append(s, h) + } else { + mft.ids[idx] = slices.Insert(s, n, h) + } +} + +func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { + r := mft.tree.aggregateInterval(x, y) + for _, t := range r.tails { + if t.len() != cachedBits { + panic("BUG: inMemFPTree.aggregateInterval: bad prefix bit count") + } + ids := mft.ids[t.bits()] + for _, id := range ids { + // FIXME: this can be optimized as the IDs are ordered + if idWithinInterval(id, x, y, r.itype) { + r.fp.update(id) + r.count++ + } + } + } + return fpResult{fp: r.fp, count: r.count} +} + +func idWithinInterval(id, x, y []byte, itype int) bool { + switch itype { + case 0: + return true + case -1: + return bytes.Compare(id, x) >= 0 && bytes.Compare(id, y) < 0 + default: + return bytes.Compare(id, y) < 0 || bytes.Compare(id, x) >= 0 + } +} + +// TBD: perhaps 
use json-based SELECTs +// TBD: extra cache for after-24bit entries +// TBD: benchmark 24-bit limit (not going beyond the cache) +// TBD: optimize, get rid of binary.BigEndian.* diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go new file mode 100644 index 0000000000..ba05a3e9e3 --- /dev/null +++ b/sync2/dbsync/dbsync_test.go @@ -0,0 +1,317 @@ +package dbsync + +import ( + "math/bits" + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func TestPrefix(t *testing.T) { + for _, tc := range []struct { + p prefix + s string + bits uint64 + len int + left prefix + right prefix + gotCacheIndex bool + cacheIndex cacheIndex + }{ + { + p: 0, + s: "<0>", + len: 0, + bits: 0, + left: 0b0_000001, + right: 0b1_000001, + gotCacheIndex: true, + cacheIndex: 0, + }, + { + p: 0b0_000001, + s: "<1:0>", + len: 1, + bits: 0, + left: 0b00_000010, + right: 0b01_000010, + gotCacheIndex: true, + cacheIndex: 1, + }, + { + p: 0b1_000001, + s: "<1:1>", + len: 1, + bits: 1, + left: 0b10_000010, + right: 0b11_000010, + gotCacheIndex: true, + cacheIndex: 2, + }, + { + p: 0b00_000010, + s: "<2:00>", + len: 2, + bits: 0, + left: 0b000_000011, + right: 0b001_000011, + gotCacheIndex: true, + cacheIndex: 3, + }, + { + p: 0b01_000010, + s: "<2:01>", + len: 2, + bits: 1, + left: 0b010_000011, + right: 0b011_000011, + gotCacheIndex: true, + cacheIndex: 4, + }, + { + p: 0b10_000010, + s: "<2:10>", + len: 2, + bits: 2, + left: 0b100_000011, + right: 0b101_000011, + gotCacheIndex: true, + cacheIndex: 5, + }, + { + p: 0b11_000010, + s: "<2:11>", + len: 2, + bits: 3, + left: 0b110_000011, + right: 0b111_000011, + gotCacheIndex: true, + cacheIndex: 6, + }, + { + p: 0x3fffffd8, + s: "<24:111111111111111111111111>", + len: 24, + bits: 0xffffff, + left: 0x7fffff99, + right: 0x7fffffd9, + 
gotCacheIndex: true, + cacheIndex: 0x1fffffe, + }, + { + p: 0x7fffff99, + s: "<25:1111111111111111111111110>", + len: 25, + bits: 0x1fffffe, + left: 0xffffff1a, + right: 0xffffff5a, + gotCacheIndex: false, // len > 24 + }, + } { + require.Equal(t, tc.s, tc.p.String()) + require.Equal(t, tc.bits, tc.p.bits()) + require.Equal(t, tc.len, tc.p.len()) + require.Equal(t, tc.left, tc.p.left()) + require.Equal(t, tc.right, tc.p.right()) + idx, gotIdx := tc.p.cacheIndex() + require.Equal(t, tc.gotCacheIndex, gotIdx) + if gotIdx { + require.Equal(t, tc.cacheIndex, idx) + } + } +} + +func TestHashPrefix(t *testing.T) { + for _, tc := range []struct { + h string + l int + p prefix + preFirst0 prefix + preFirst1 prefix + }{ + { + h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + l: 0, + p: 0, + preFirst0: 0b1_000001, + preFirst1: 0, + }, + { + h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + l: 1, + p: 0b1_000001, + preFirst0: 0b1_000001, + preFirst1: 0, + }, + { + h: "2BCDEF1234567890000000000000000000000000000000000000000000000000", + l: 1, + p: 0b0_000001, + preFirst0: 0, + preFirst1: 0b00_000010, + }, + { + h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + l: 4, + p: 0b1010_000100, + preFirst0: 0b1_000001, + preFirst1: 0, + }, + { + h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + l: 57, + p: 0x55e6f7891a2b3c79, + preFirst0: 0b1_000001, + preFirst1: 0, + }, + { + h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + l: 58, + p: 0xabcdef12345678ba, + preFirst0: 0b1_000001, + preFirst1: 0, + }, + { + h: "0000000000000000000000000000000000000000000000000000000000000000", + l: 0, + p: 0, + preFirst0: 0, + preFirst1: 58, + }, + { + h: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + l: 0, + p: 0, + preFirst0: 0xfffffffffffffffa, + preFirst1: 0, + }, + } { + h := types.HexToHash32(tc.h) + require.Equal(t, tc.p, hashPrefix(h[:], tc.l), "hash prefix: 
h %s l %d", tc.h, tc.l) + require.Equal(t, tc.preFirst0, preFirst0(h[:]), "preFirst0: h %s", tc.h) + require.Equal(t, tc.preFirst1, preFirst1(h[:]), "preFirst1: h %s", tc.h) + } +} + +func TestCommonPrefix(t *testing.T) { + for _, tc := range []struct { + a, b string + p prefix + }{ + { + a: "0000000000000000000000000000000000000000000000000000000000000000", + b: "8000000000000000000000000000000000000000000000000000000000000000", + p: 0, + }, + { + a: "A000000000000000000000000000000000000000000000000000000000000000", + b: "8000000000000000000000000000000000000000000000000000000000000000", + p: 0b10_000010, + }, + { + a: "A000000000000000000000000000000000000000000000000000000000000000", + b: "A800000000000000000000000000000000000000000000000000000000000000", + p: 0b1010_000100, + }, + { + a: "ABCDEF1234567890000000000000000000000000000000000000000000000000", + b: "ABCDEF1234567800000000000000000000000000000000000000000000000000", + p: 0x2af37bc48d159e38, + }, + { + a: "ABCDEF1234567890123456789ABCDEF000000000000000000000000000000000", + b: "ABCDEF1234567890123456789ABCDEF000000000000000000000000000000000", + p: 0xabcdef12345678ba, + }, + } { + a := types.HexToHash32(tc.a) + b := types.HexToHash32(tc.b) + require.Equal(t, tc.p, commonPrefix(a[:], b[:])) + } +} + +const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" + +func TestRmme(t *testing.T) { + t.Skip("slow tmp test") + counts := make(map[uint64]uint64) + prefLens := make(map[int]int) + db, err := statesql.Open("file:" + dbFile) + require.NoError(t, err) + defer db.Close() + var prev uint64 + first := true + // where epoch=23 + _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { + var id types.Hash32 + stmt.ColumnBytes(0, id[:]) + v := load64(id[:]) + counts[v>>40]++ + if first { + first = false + } else { + prefLens[bits.LeadingZeros64(prev^v)]++ + } + prev = v + return true + }) + require.NoError(t, err) + countFreq := 
make(map[uint64]int) + for _, c := range counts { + countFreq[c]++ + } + ks := maps.Keys(countFreq) + slices.Sort(ks) + for _, c := range ks { + t.Logf("%d: %d times", c, countFreq[c]) + } + pls := maps.Keys(prefLens) + slices.Sort(pls) + for _, pl := range pls { + t.Logf("pl %d: %d times", pl, prefLens[pl]) + } +} + +func TestInMemFPTree(t *testing.T) { + var mft inMemFPTree + var hs []types.Hash32 + for _, hex := range []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "123456789ABCDEF0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "ABCDEF1234567890000000000000000000000000000000000000000000000000", + } { + h := types.HexToHash32(hex) + hs = append(hs, h) + mft.addHash(h[:]) + } + var sb strings.Builder + mft.tree.dump(&sb) + t.Logf("QQQQQ: tree:\n%s", sb.String()) + require.Equal(t, hexToFingerprint("642464b773377bbddddddddd"), mft.tree.nodes[0].fp) + require.Equal(t, fpResult{ + fp: hexToFingerprint("642464b773377bbddddddddd"), + count: 5, + }, mft.aggregateInterval(hs[0][:], hs[0][:])) + require.Equal(t, fpResult{ + fp: hexToFingerprint("642464b773377bbddddddddd"), + count: 5, + }, mft.aggregateInterval(hs[4][:], hs[4][:])) + require.Equal(t, fpResult{ + fp: hexToFingerprint("000000000000000000000000"), + count: 1, + }, mft.aggregateInterval(hs[0][:], hs[1][:])) + require.Equal(t, fpResult{ + fp: hexToFingerprint("cfe98ba54761032ddddddddd"), + count: 3, + }, mft.aggregateInterval(hs[1][:], hs[4][:])) + // TBD: test reverse range +} From 4450fcefce36aecc0d1caee592ebe85c568cd637 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 18 Jun 2024 16:14:47 +0400 Subject: [PATCH 30/76] wip2 --- sync2/dbsync/dbsync.go | 207 ++++++++++++++++++++++++++++-------- sync2/dbsync/dbsync_test.go | 1 + 2 files changed, 162 insertions(+), 46 deletions(-) diff --git a/sync2/dbsync/dbsync.go 
b/sync2/dbsync/dbsync.go index 1d9a4db167..87ac0a4e4c 100644 --- a/sync2/dbsync/dbsync.go +++ b/sync2/dbsync/dbsync.go @@ -8,6 +8,7 @@ import ( "io" "math/bits" "os" + "strconv" "golang.org/x/exp/slices" ) @@ -46,18 +47,33 @@ func hexToFingerprint(s string) fingerprint { return fp } -// const ( -// nodeFlagLeaf = 1 << 32 -// nodeFlagChanged = 1 << 31 -// ) +const ( + nodeFlagLeaf = 1 << 31 + nodeFlagMask = nodeFlagLeaf +) // NOTE: all leafs are on the last level type node struct { // 16-byte structure with alignment // The cache is 512 MiB per 1<<24 (16777216) IDs - fp fingerprint - count uint32 + fp fingerprint + c uint32 +} + +func (node *node) empty() bool { + return node.c == 0 +} + +func (node *node) leaf() bool { + return node.c&nodeFlagLeaf != 0 +} + +func (node *node) count() uint32 { + if node.leaf() { + return 1 + } + return node.c } type cacheIndex uint32 @@ -160,7 +176,7 @@ type fpResult struct { } type aggResult struct { - tails []prefix + tails []uint64 fp fingerprint count uint32 itype int @@ -168,20 +184,66 @@ type aggResult struct { func (r *aggResult) update(node *node) { r.fp.update(node.fp[:]) - r.count += node.count + r.count += node.c } type fpTree struct { nodes [cachedSize * 2]node } +func (ft *fpTree) pushDown(node *node, p prefix) { + if p.len() >= cachedBits { + return + } + pushDownBit := node.c & (1 << (cachedBits - 1 - p.len())) + var pushDownPrefix prefix + if pushDownBit == 0 { + pushDownPrefix = p.left() + } else { + pushDownPrefix = p.right() + } + pushDownIdx, haveIdx := pushDownPrefix.cacheIndex() + if !haveIdx { + panic("BUG: no idx for pushDownPrefix") + } + pushDownNode := &ft.nodes[pushDownIdx] + + // QQQQQ: rm + idx, _ := p.cacheIndex() + fmt.Fprintf(os.Stderr, "QQQQQ: idx: %d pushDownIdx: %d c: %d\n", idx, pushDownIdx, pushDownNode.c) + + if !pushDownNode.empty() { + panic("BUG: non-empty push down node") + } + pushDownNode.c = node.c + pushDownNode.fp = node.fp +} + func (ft *fpTree) addHash(h []byte) { var p prefix v := 
binary.BigEndian.Uint64(h[:8]) + vFull := v for { idx, haveIdx := p.cacheIndex() - ft.nodes[idx].fp.update(h[:]) - ft.nodes[idx].count++ + if !haveIdx { + panic("BUG: no cache idx") + } + node := &ft.nodes[idx] + switch { + case node.empty(): + node.c = uint32(vFull>>(64-cachedBits)) | nodeFlagLeaf + node.fp.update(h[:]) + fmt.Fprintf(os.Stderr, "QQQQQ: leaf at idx: %d\n", idx) + return + case node.leaf(): + // push down the old leaf + ft.pushDown(node, p) + node.c = 2 + node.fp.update(h[:]) + default: + node.c++ + node.fp.update(h[:]) + } switch { case !haveIdx: panic("BUG: no cache idx") @@ -199,45 +261,92 @@ func (ft *fpTree) addHash(h []byte) { } func (ft *fpTree) aggregateLeft(v uint64, p prefix, r *aggResult) { - bit := v & (1 << (63 - p.len())) - switch { - case p.len() >= cachedBits: - r.tails = append(r.tails, p) + if p.len() >= cachedBits { + r.tails = append(r.tails, p.bits()<<(24-p.len())) fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, p) - case bit == 0: - idx, gotIdx := p.right().cacheIndex() - if !gotIdx { - panic("BUG: no idx") - } - r.update(&ft.nodes[idx]) - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s\n", v, p, - ft.nodes[idx].count, ft.nodes[idx].fp) + return + } + idx, gotIdx := p.right().cacheIndex() + if !gotIdx { + panic("BUG: no idx") + } + node := &ft.nodes[idx] + if node.leaf() { + r.tails = append(r.tails, uint64(node.c & ^uint32(nodeFlagMask))) + return + } + if bit := v & (1 << (63 - p.len())); bit == 0 { + r.update(node) + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s + go left\n", v, p, + node.c, node.fp) ft.aggregateLeft(v, p.left(), r) - default: + } else { fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) ft.aggregateLeft(v, p.right(), r) } + + // switch { + // case p.len() >= cachedBits: + // r.tails = append(r.tails, p) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, p) + // 
case bit == 0: + // idx, gotIdx := p.right().cacheIndex() + // if !gotIdx { + // panic("BUG: no idx") + // } + // r.update(&ft.nodes[idx]) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s\n", v, p, + // ft.nodes[idx].c, ft.nodes[idx].fp) + // ft.aggregateLeft(v, p.left(), r) + // default: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) + // ft.aggregateLeft(v, p.right(), r) + // } } func (ft *fpTree) aggregateRight(v uint64, p prefix, r *aggResult) { - bit := v & (1 << (63 - p.len())) - switch { - case p.len() >= cachedBits: - r.tails = append(r.tails, p) + if p.len() >= cachedBits { + r.tails = append(r.tails, p.bits()<<(24-p.len())) fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) - case bit == 0: - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> go left\n", v, p) + return + } + idx, gotIdx := p.left().cacheIndex() + if !gotIdx { + panic("BUG: no idx") + } + node := &ft.nodes[idx] + if node.leaf() { + r.tails = append(r.tails, uint64(node.c & ^uint32(nodeFlagMask))) + return + } + if bit := v & (1 << (63 - p.len())); bit == 0 { + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> go left\n", v, p) ft.aggregateRight(v, p.left(), r) - default: - idx, gotIdx := p.left().cacheIndex() - if !gotIdx { - panic("BUG: no idx") - } - r.update(&ft.nodes[idx]) - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> add count %d fp %s + go right\n", v, p, - ft.nodes[idx].count, ft.nodes[idx].fp) + } else { + r.update(node) + fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> add count %d fp %s + go right\n", v, p, + node.c, node.fp) ft.aggregateRight(v, p.right(), r) } + + // bit := v & (1 << (63 - p.len())) + // switch { + // case p.len() >= cachedBits: + // r.tails = append(r.tails, p) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) + // case bit == 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: 
aggregateRight: %016x %s: 0 -> go left\n", v, p) + // ft.aggregateRight(v, p.left(), r) + // default: + // idx, gotIdx := p.left().cacheIndex() + // if !gotIdx { + // panic("BUG: no idx") + // } + // r.update(&ft.nodes[idx]) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> add count %d fp %s + go right\n", v, p, + // ft.nodes[idx].c, ft.nodes[idx].fp) + // ft.aggregateRight(v, p.right(), r) + // } } func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { @@ -265,12 +374,21 @@ func (ft *fpTree) dumpNode(w io.Writer, p prefix, indent, dir string) { if !gotIdx { return } - c := ft.nodes[idx].count - if c == 0 { + node := &ft.nodes[idx] + if node.empty() { return } - fmt.Fprintf(w, "%s%s%s %d\n", indent, dir, ft.nodes[idx].fp, c) - if c > 1 { + var countStr string + leaf := node.leaf() + if leaf { + countStr = "LEAF" + } else if node.empty() { + countStr = "EMPTY" + } else { + countStr = strconv.Itoa(int(node.count())) + } + fmt.Fprintf(w, "%s%s%s %s\n", indent, dir, node.fp, countStr) + if !leaf { indent += " " ft.dumpNode(w, p.left(), indent, "l: ") ft.dumpNode(w, p.right(), indent, "r: ") @@ -278,7 +396,7 @@ func (ft *fpTree) dumpNode(w io.Writer, p prefix, indent, dir string) { } func (ft *fpTree) dump(w io.Writer) { - if ft.nodes[0].count == 0 { + if ft.nodes[0].c == 0 { fmt.Fprintln(w, "empty tree") } else { ft.dumpNode(w, 0, "", "") @@ -307,10 +425,7 @@ func (mft *inMemFPTree) addHash(h []byte) { func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { r := mft.tree.aggregateInterval(x, y) for _, t := range r.tails { - if t.len() != cachedBits { - panic("BUG: inMemFPTree.aggregateInterval: bad prefix bit count") - } - ids := mft.ids[t.bits()] + ids := mft.ids[t] for _, id := range ids { // FIXME: this can be optimized as the IDs are ordered if idWithinInterval(id, x, y, r.itype) { diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go index ba05a3e9e3..acd1a2c8f1 100644 --- a/sync2/dbsync/dbsync_test.go +++ 
b/sync2/dbsync/dbsync_test.go @@ -289,6 +289,7 @@ func TestInMemFPTree(t *testing.T) { "8888888888888888888888888888888888888888888888888888888888888888", "ABCDEF1234567890000000000000000000000000000000000000000000000000", } { + t.Logf("QQQQQ: ADD: %s", hex) h := types.HexToHash32(hex) hs = append(hs, h) mft.addHash(h[:]) From 34224ceca6781f9749bac59a9b082ac339a4bd32 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 18 Jun 2024 20:40:15 +0400 Subject: [PATCH 31/76] fptree works --- sync2/dbsync/dbsync.go | 66 +++++++++++++++++++-------- sync2/dbsync/dbsync_test.go | 91 +++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 18 deletions(-) diff --git a/sync2/dbsync/dbsync.go b/sync2/dbsync/dbsync.go index 87ac0a4e4c..039bb55586 100644 --- a/sync2/dbsync/dbsync.go +++ b/sync2/dbsync/dbsync.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "math/bits" - "os" "strconv" "golang.org/x/exp/slices" @@ -54,6 +53,13 @@ const ( // NOTE: all leafs are on the last level +// type node struct { +// fp fingerprint +// c uint32 +// left, right uint32 +// refCount atomic.Uint32 +// } + type node struct { // 16-byte structure with alignment // The cache is 512 MiB per 1<<24 (16777216) IDs @@ -184,7 +190,7 @@ type aggResult struct { func (r *aggResult) update(node *node) { r.fp.update(node.fp[:]) - r.count += node.c + r.count += node.count() } type fpTree struct { @@ -209,8 +215,8 @@ func (ft *fpTree) pushDown(node *node, p prefix) { pushDownNode := &ft.nodes[pushDownIdx] // QQQQQ: rm - idx, _ := p.cacheIndex() - fmt.Fprintf(os.Stderr, "QQQQQ: idx: %d pushDownIdx: %d c: %d\n", idx, pushDownIdx, pushDownNode.c) + // idx, _ := p.cacheIndex() + // fmt.Fprintf(os.Stderr, "QQQQQ: idx: %d pushDownIdx: %d c: %d\n", idx, pushDownIdx, pushDownNode.c) if !pushDownNode.empty() { panic("BUG: non-empty push down node") @@ -233,7 +239,7 @@ func (ft *fpTree) addHash(h []byte) { case node.empty(): node.c = uint32(vFull>>(64-cachedBits)) | nodeFlagLeaf node.fp.update(h[:]) - 
fmt.Fprintf(os.Stderr, "QQQQQ: leaf at idx: %d\n", idx) + // fmt.Fprintf(os.Stderr, "QQQQQ: leaf at idx: %d\n", idx) return case node.leaf(): // push down the old leaf @@ -263,25 +269,35 @@ func (ft *fpTree) addHash(h []byte) { func (ft *fpTree) aggregateLeft(v uint64, p prefix, r *aggResult) { if p.len() >= cachedBits { r.tails = append(r.tails, p.bits()<<(24-p.len())) - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, p) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, p) return } - idx, gotIdx := p.right().cacheIndex() + idx, gotIdx := p.cacheIndex() if !gotIdx { panic("BUG: no idx") } node := &ft.nodes[idx] + if node.empty() { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: idx=%d: %016x %s: empty node\n", idx, v, p) + return + } if node.leaf() { r.tails = append(r.tails, uint64(node.c & ^uint32(nodeFlagMask))) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: leaf\n", v, p) return } if bit := v & (1 << (63 - p.len())); bit == 0 { - r.update(node) - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s + go left\n", v, p, - node.c, node.fp) + rIdx, gotIdx := p.right().cacheIndex() + if !gotIdx { + panic("BUG: no idx") + } + rNode := &ft.nodes[rIdx] + r.update(rNode) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s + go left\n", v, p, + // rNode.c, rNode.fp) ft.aggregateLeft(v, p.left(), r) } else { - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) ft.aggregateLeft(v, p.right(), r) } @@ -307,25 +323,35 @@ func (ft *fpTree) aggregateLeft(v uint64, p prefix, r *aggResult) { func (ft *fpTree) aggregateRight(v uint64, p prefix, r *aggResult) { if p.len() >= cachedBits { r.tails = append(r.tails, p.bits()<<(24-p.len())) - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) + // fmt.Fprintf(os.Stderr, 
"QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) return } - idx, gotIdx := p.left().cacheIndex() + idx, gotIdx := p.cacheIndex() if !gotIdx { panic("BUG: no idx") } node := &ft.nodes[idx] + if node.empty() { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: idx=%d: %016x %s: empty node\n", idx, v, p) + return + } if node.leaf() { r.tails = append(r.tails, uint64(node.c & ^uint32(nodeFlagMask))) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: leaf\n", v, p) return } if bit := v & (1 << (63 - p.len())); bit == 0 { - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> go left\n", v, p) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> go left\n", v, p) ft.aggregateRight(v, p.left(), r) } else { - r.update(node) - fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> add count %d fp %s + go right\n", v, p, - node.c, node.fp) + lIdx, gotIdx := p.left().cacheIndex() + if !gotIdx { + panic("BUG: no idx") + } + lNode := &ft.nodes[lIdx] + r.update(lNode) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> add count %d fp %s + go right\n", v, p, + // lNode.c, lNode.fp) ft.aggregateRight(v, p.right(), r) } @@ -350,6 +376,7 @@ func (ft *fpTree) aggregateRight(v uint64, p prefix, r *aggResult) { } func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateInterval: %s %s\n", hex.EncodeToString(x), hex.EncodeToString(y)) var r aggResult r.itype = bytes.Compare(x, y) switch { @@ -387,7 +414,7 @@ func (ft *fpTree) dumpNode(w io.Writer, p prefix, indent, dir string) { } else { countStr = strconv.Itoa(int(node.count())) } - fmt.Fprintf(w, "%s%s%s %s\n", indent, dir, node.fp, countStr) + fmt.Fprintf(w, "%s%sidx=%d %s %s\n", indent, dir, idx, node.fp, countStr) if !leaf { indent += " " ft.dumpNode(w, p.left(), indent, "l: ") @@ -429,8 +456,11 @@ func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { for _, id := range ids { // FIXME: this can be optimized as 
the IDs are ordered if idWithinInterval(id, x, y, r.itype) { + // fmt.Fprintf(os.Stderr, "QQQQQ: including tail: %s\n", hex.EncodeToString(id)) r.fp.update(id) r.count++ + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: NOT including tail: %s\n", hex.EncodeToString(id)) } } } diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go index acd1a2c8f1..85408aa595 100644 --- a/sync2/dbsync/dbsync_test.go +++ b/sync2/dbsync/dbsync_test.go @@ -1,6 +1,7 @@ package dbsync import ( + "bytes" "math/bits" "slices" "strings" @@ -8,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/exp/maps" + // "golang.org/x/exp/rand" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" @@ -316,3 +318,92 @@ func TestInMemFPTree(t *testing.T) { }, mft.aggregateInterval(hs[1][:], hs[4][:])) // TBD: test reverse range } + +func TestInMemFPTreeRmme1(t *testing.T) { + var mft inMemFPTree + var hs []types.Hash32 + for _, hex := range []string{ + "829977b444c8408dcddc1210536f3b3bdc7fd97777426264b9ac8f70b97a7fd1", + "6e476ca729c3840d0118785496e488124ee7dade1aef0c87c6edc78f72e4904f", + "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", + "e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", + } { + t.Logf("QQQQQ: ADD: %s", hex) + h := types.HexToHash32(hex) + hs = append(hs, h) + mft.addHash(h[:]) + } + var sb strings.Builder + mft.tree.dump(&sb) + t.Logf("QQQQQ: tree:\n%s", sb.String()) + require.Equal(t, hexToFingerprint("a76fc452775b55e0dacd8be5"), mft.tree.nodes[0].fp) + require.Equal(t, fpResult{ + fp: hexToFingerprint("2019cb0c56fbd36d197d4c4c"), + count: 2, + }, mft.aggregateInterval(hs[0][:], hs[3][:])) +} + +type hashList []types.Hash32 + +func (l hashList) findGTE(h types.Hash32) int { + p, _ := slices.BinarySearchFunc(l, h, func(a, b types.Hash32) int { + return a.Compare(b) + }) + return p +} + +func TestInMemFPTreeManyItems(t *testing.T) { + var mft inMemFPTree + const numItems = 1 << 20 
+ hs := make(hashList, numItems) + var fp fingerprint + for i := range hs { + h := types.RandomHash() + hs[i] = h + mft.addHash(h[:]) + fp.update(h[:]) + } + // var sb strings.Builder + // mft.tree.dump(&sb) + // t.Logf("QQQQQ: tree:\n%s", sb.String()) + slices.SortFunc(hs, func(a, b types.Hash32) int { + return a.Compare(b) + }) + // for i, h := range hs { + // t.Logf("h[%d] = %s", i, h.String()) + // } + require.Equal(t, fp, mft.tree.nodes[0].fp) + for i := 0; i < 100; i++ { + // TBD: allow reverse order + // TBD: pick some intervals from the hashes + x := types.RandomHash() + y := types.RandomHash() + // x := hs[rand.Intn(numItems)] + // y := hs[rand.Intn(numItems)] + c := bytes.Compare(x[:], y[:]) + var ( + expFP fingerprint + expN uint32 + ) + if c > 0 { + x, y = y, x + } + if c == 0 { + expFP = fp + expN = numItems + } else { + pX := hs.findGTE(x) + pY := hs.findGTE(y) + // t.Logf("x=%s y=%s pX=%d y=%d", x.String(), y.String(), pX, pY) + for p := pX; p < pY; p++ { + // t.Logf("XOR %s", hs[p].String()) + expFP.update(hs[p][:]) + } + expN = uint32(pY - pX) + } + require.Equal(t, fpResult{ + fp: expFP, + count: expN, + }, mft.aggregateInterval(x[:], y[:])) + } +} From a970d21bd7c127f1e203d25adbae331b569e0645 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 19 Jun 2024 17:37:26 +0400 Subject: [PATCH 32/76] pool based tree --- sync2/dbsync/dbsync.go | 672 ++++++++++++++++++++++------------- sync2/dbsync/dbsync_test.go | 555 +++++++++++++++++------------ sync2/dbsync/refcountpool.go | 75 ++++ 3 files changed, 826 insertions(+), 476 deletions(-) create mode 100644 sync2/dbsync/refcountpool.go diff --git a/sync2/dbsync/dbsync.go b/sync2/dbsync/dbsync.go index 039bb55586..4a9755b9dd 100644 --- a/sync2/dbsync/dbsync.go +++ b/sync2/dbsync/dbsync.go @@ -7,22 +7,25 @@ import ( "fmt" "io" "math/bits" + "slices" "strconv" - - "golang.org/x/exp/slices" ) const ( fingerprintBytes = 12 - cachedBits = 24 - cachedSize = 1 << cachedBits - cacheMask = cachedSize - 1 - 
maxIDBytes = 32 - bit63 = 1 << 63 + // cachedBits = 24 + // cachedSize = 1 << cachedBits + // cacheMask = cachedSize - 1 + maxIDBytes = 32 + bit63 = 1 << 63 ) type fingerprint [fingerprintBytes]byte +func (fp fingerprint) Compare(other fingerprint) int { + return bytes.Compare(fp[:], other[:]) +} + func (fp fingerprint) String() string { return hex.EncodeToString(fp[:]) } @@ -33,6 +36,13 @@ func (fp *fingerprint) update(h []byte) { } } +func (fp *fingerprint) bitFromLeft(n int) bool { + if n > fingerprintBytes*8 { + panic("BUG: bad fingerprint bit index") + } + return (fp[n>>3]>>(7-n&0x7))&1 != 0 +} + func hexToFingerprint(s string) fingerprint { b, err := hex.DecodeString(s) if err != nil { @@ -46,43 +56,142 @@ func hexToFingerprint(s string) fingerprint { return fp } -const ( - nodeFlagLeaf = 1 << 31 - nodeFlagMask = nodeFlagLeaf -) +// const ( +// nodeFlagLeaf = 1 << 31 +// nodeFlagMask = nodeFlagLeaf +// ) // NOTE: all leafs are on the last level -// type node struct { -// fp fingerprint -// c uint32 -// left, right uint32 -// refCount atomic.Uint32 +// type nodeIndex uint32 + +// const noIndex nodeIndex = ^nodeIndex(0) + +// // TODO: nodePool limiting +// type nodePool struct { +// mtx sync.Mutex +// nodes []node +// // freeList is 1-based so that nodePool doesn't need a constructor +// freeList nodeIndex // } -type node struct { - // 16-byte structure with alignment - // The cache is 512 MiB per 1<<24 (16777216) IDs - fp fingerprint - c uint32 +// func (np *nodePool) node(idx nodeIndex) node { +// np.mtx.Lock() +// defer np.mtx.Unlock() +// return np.nodeUnlocked(idx) +// } + +// func (np *nodePool) nodeUnlocked(idx nodeIndex) node { +// node := &np.nodes[idx] +// refs := node.refCount +// if refs < 0 { +// panic("BUG: negative nodePool entry refcount") +// } else if refs == 0 { +// panic("BUG: referencing a free nodePool entry") +// } +// return *node +// } + +// func (np *nodePool) add(fp fingerprint, c uint32, left, right nodeIndex) nodeIndex { +// 
np.mtx.Lock() +// defer np.mtx.Unlock() +// var idx nodeIndex +// // validate indices +// if left != noIndex { +// np.nodeUnlocked(left) +// } +// if right != noIndex { +// np.nodeUnlocked(right) +// } +// if np.freeList != 0 { +// idx = nodeIndex(np.freeList - 1) +// np.freeList = np.nodes[idx].left +// np.nodes[idx].refCount++ +// if np.nodes[idx].refCount != 1 { +// panic("BUG: refCount != 1 for a node taken from the freelist") +// } +// } else { +// idx = nodeIndex(len(np.nodes)) +// np.nodes = append(np.nodes, node{refCount: 1}) +// } +// node := &np.nodes[idx] +// node.fp = fp +// node.c = c +// node.left = left +// node.right = right +// return idx +// } + +// func (np *nodePool) release(idx nodeIndex) { +// np.mtx.Lock() +// defer np.mtx.Unlock() +// node := &np.nodes[idx] +// if node.refCount <= 0 { +// panic("BUG: negative nodePool entry refcount") +// } +// node.refCount-- +// if node.refCount == 0 { +// node.left = np.freeList +// np.freeList = idx + 1 +// } +// } + +// func (np *nodePool) ref(idx nodeIndex) { +// np.mtx.Lock() +// np.nodes[idx].refCount++ +// np.mtx.Unlock() +// } + +type nodeIndex uint32 + +const noIndex = ^nodeIndex(0) + +type nodePool struct { + rcPool[node, nodeIndex] } -func (node *node) empty() bool { - return node.c == 0 +func (np *nodePool) add(fp fingerprint, c uint32, left, right nodeIndex) nodeIndex { + return np.rcPool.add(node{fp: fp, c: c, left: left, right: right}) } -func (node *node) leaf() bool { - return node.c&nodeFlagLeaf != 0 +func (np *nodePool) node(idx nodeIndex) node { + return np.rcPool.item(idx) } -func (node *node) count() uint32 { - if node.leaf() { - return 1 - } - return node.c +// fpTree node. 
+// The nodes are immutable except for refCount field, which should +// only be used directly by nodePool methods +type node struct { + fp fingerprint + c uint32 + left, right nodeIndex } -type cacheIndex uint32 +func (n node) leaf() bool { + return n.left == noIndex && n.right == noIndex +} + +// type node struct { +// // 16-byte structure with alignment +// // The cache is 512 MiB per 1<<24 (16777216) IDs +// fp fingerprint +// c uint32 +// } + +// func (node *node) empty() bool { +// return node.c == 0 +// } + +// func (node *node) leaf() bool { +// return node.c&nodeFlagLeaf != 0 +// } + +// func (node *node) count() uint32 { +// if node.leaf() { +// return 1 +// } +// return node.c +// } const ( prefixLenBits = 6 @@ -113,27 +222,11 @@ func (p prefix) right() prefix { return p.left() + (1 << prefixLenBits) } -func (p prefix) cacheIndex() (cacheIndex, bool) { - if l := p.len(); l <= cachedBits { - // Notation: prefix(cacheIndex) - // - // empty(0) - // / \ - // / \ - // / \ - // 0(1) 1(2) - // / \ / \ - // / \ / \ - // 00(3) 01(4) 10(5) 11(6) - - // indexing starts at 1 - // left: n = n*2 - // right: n = n*2+1 - // but in the end we substract 1 to make it 0-based again - - return cacheIndex(p.bits() | (1 << l) - 1), true +func (p prefix) dir(bit bool) prefix { + if bit { + return p.right() } - return 0, false + return p.left() } func (p prefix) String() string { @@ -144,29 +237,48 @@ func (p prefix) String() string { return fmt.Sprintf("<%d:%s>", p.len(), b[64-p.len():]) } -func load64(h []byte) uint64 { - return binary.BigEndian.Uint64(h[:8]) +func (p prefix) highBit() bool { + if p == 0 { + return false + } + return p.bits()>>(p.len()-1) != 0 } -func hashPrefix(h []byte, nbits int) prefix { - if nbits < 0 || nbits > maxPrefixLen { - panic("BUG: bad prefix length") - } - if nbits == 0 { +// shift removes the highest bit from the prefix +// TBD: QQQQQ: test shift +func (p prefix) shift() prefix { + switch l := uint64(p.len()); l { + case 0: + panic("BUG: can't 
shift zero prefix") + case 1: return 0 + default: + return prefix(((p.bits() & ((1 << (l - 1)) - 1)) << prefixLenBits) + l - 1) } - v := load64(h) - return prefix((v>>(64-nbits-prefixLenBits))&prefixBitMask + uint64(nbits)) } +func load64(h []byte) uint64 { + return binary.BigEndian.Uint64(h[:8]) +} + +// func hashPrefix(h []byte, nbits int) prefix { +// if nbits < 0 || nbits > maxPrefixLen { +// panic("BUG: bad prefix length") +// } +// if nbits == 0 { +// return 0 +// } +// v := load64(h) +// return prefix((v>>(64-nbits-prefixLenBits))&prefixBitMask + uint64(nbits)) +// } + func preFirst0(h []byte) prefix { l := min(maxPrefixLen, bits.LeadingZeros64(^load64(h))) - return hashPrefix(h, l) + return prefix(((1<>(64-l))<= cachedBits { - return + np *nodePool + root nodeIndex + maxDepth int +} + +func newFPTree(np *nodePool, maxDepth int) *fpTree { + return &fpTree{np: np, root: noIndex, maxDepth: maxDepth} +} + +func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: fpA %s fpB %s p %s\n", fpA.String(), fpB.String(), p) + fpCombined := fpA + fpCombined.update(fpB[:]) + if ft.maxDepth != 0 && p.len() == ft.maxDepth { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add at maxDepth\n") + return ft.np.add(fpCombined, curCount+1, noIndex, noIndex) + } + if curCount != 1 { + panic("BUG: pushDown of non-1-leaf below maxDepth") + } + dirA := fpA.bitFromLeft(p.len()) + dirB := fpB.bitFromLeft(p.len()) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: bitFromLeft %d: dirA %v dirB %v\n", p.len(), dirA, dirB) + if dirA == dirB { + childIdx := ft.pushDown(fpA, fpB, p.dir(dirA), 1) + if dirA { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: left\n") + return ft.np.add(fpCombined, 2, noIndex, childIdx) + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: right\n") + return ft.np.add(fpCombined, 2, childIdx, noIndex) + } } - pushDownBit := node.c & (1 << (cachedBits - 1 - p.len())) - var 
pushDownPrefix prefix - if pushDownBit == 0 { - pushDownPrefix = p.left() + + idxA := ft.np.add(fpA, 1, noIndex, noIndex) + idxB := ft.np.add(fpB, curCount, noIndex, noIndex) + if dirA { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add A-B\n") + return ft.np.add(fpCombined, 2, idxB, idxA) } else { - pushDownPrefix = p.right() - } - pushDownIdx, haveIdx := pushDownPrefix.cacheIndex() - if !haveIdx { - panic("BUG: no idx for pushDownPrefix") + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add B-A\n") + return ft.np.add(fpCombined, 2, idxA, idxB) + } +} + +func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { + if idx == noIndex { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: addNew fp %s p %s idx %d\n", fp.String(), p.String(), idx) + return ft.np.add(fp, 1, noIndex, noIndex) + } + node := ft.np.node(idx) + // We've got a copy of the node, so we release it right away. + // This way, it'll likely be reused for the new nodes created + // as this hash is being added, as the node pool's freeList is + // LIFO + ft.np.release(idx) + if node.c == 1 || (ft.maxDepth != 0 && p.len() == ft.maxDepth) { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: pushDown fp %s p %s idx %d\n", fp.String(), p.String(), idx) + // we're at a leaf node, need to push down the old fingerprint, or, + // if we've reached the max depth, just update the current node + return ft.pushDown(fp, node.fp, p, node.c) + } + fpCombined := fp + fpCombined.update(node.fp[:]) + if fp.bitFromLeft(p.len()) { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s idx %d\n", fp.String(), p.String(), idx) + newRight := ft.addValue(fp, p.right(), node.right) + return ft.np.add(fpCombined, node.c+1, node.left, newRight) + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s idx %d\n", fp.String(), p.String(), idx) + newLeft := ft.addValue(fp, p.left(), node.left) + return ft.np.add(fpCombined, node.c+1, newLeft, node.right) } - pushDownNode := 
&ft.nodes[pushDownIdx] +} - // QQQQQ: rm - // idx, _ := p.cacheIndex() - // fmt.Fprintf(os.Stderr, "QQQQQ: idx: %d pushDownIdx: %d c: %d\n", idx, pushDownIdx, pushDownNode.c) +func (ft *fpTree) addHash(h []byte) { + var fp fingerprint + fp.update(h) + ft.root = ft.addValue(fp, 0, ft.root) + // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: new root %d\n", ft.root) +} - if !pushDownNode.empty() { - panic("BUG: non-empty push down node") +func (ft *fpTree) followPrefix(from nodeIndex, p prefix) (nodeIndex, bool) { + // fmt.Fprintf(os.Stderr, "QQQQQ: followPrefix: from %d p %s highBit %v\n", from, p, p.highBit()) + switch { + case p == 0: + return from, true + case from == noIndex: + return noIndex, false + case ft.np.node(from).leaf(): + return from, false + case p.highBit(): + return ft.followPrefix(ft.np.node(from).right, p.shift()) + default: + return ft.followPrefix(ft.np.node(from).left, p.shift()) } - pushDownNode.c = node.c - pushDownNode.fp = node.fp } -func (ft *fpTree) addHash(h []byte) { - var p prefix - v := binary.BigEndian.Uint64(h[:8]) - vFull := v - for { - idx, haveIdx := p.cacheIndex() - if !haveIdx { - panic("BUG: no cache idx") - } - node := &ft.nodes[idx] - switch { - case node.empty(): - node.c = uint32(vFull>>(64-cachedBits)) | nodeFlagLeaf - node.fp.update(h[:]) - // fmt.Fprintf(os.Stderr, "QQQQQ: leaf at idx: %d\n", idx) - return - case node.leaf(): - // push down the old leaf - ft.pushDown(node, p) - node.c = 2 - node.fp.update(h[:]) - default: - node.c++ - node.fp.update(h[:]) - } - switch { - case !haveIdx: - panic("BUG: no cache idx") - case p.len() > cachedBits: - panic("BUG: prefix too long") - case p.len() == cachedBits: - return - case v&bit63 == 0: - p = p.left() - default: - p = p.right() - } - v <<= 1 +func (ft *fpTree) tailRefFromPrefix(p prefix) uint64 { + if p.len() != ft.maxDepth { + panic("BUG: tail from short prefix") } + return p.bits() } -func (ft *fpTree) aggregateLeft(v uint64, p prefix, r *aggResult) { - if p.len() >= 
cachedBits { - r.tails = append(r.tails, p.bits()<<(24-p.len())) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, p) - return +func (ft *fpTree) tailRefFromFingerprint(fp fingerprint) uint64 { + v := load64(fp[:]) + if ft.maxDepth >= 64 { + return v } - idx, gotIdx := p.cacheIndex() - if !gotIdx { - panic("BUG: no idx") - } - node := &ft.nodes[idx] - if node.empty() { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: idx=%d: %016x %s: empty node\n", idx, v, p) - return + // fmt.Fprintf(os.Stderr, "QQQQQ: AAAAA: v %016x maxDepth %d shift %d\n", v, ft.maxDepth, (64 - ft.maxDepth)) + return v >> (64 - ft.maxDepth) +} + +func (ft *fpTree) tailRefFromNodeAndPrefix(n node, p prefix) uint64 { + if n.c == 1 { + return ft.tailRefFromFingerprint(n.fp) + } else { + return ft.tailRefFromPrefix(p) } - if node.leaf() { - r.tails = append(r.tails, uint64(node.c & ^uint32(nodeFlagMask))) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: leaf\n", v, p) +} + +func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) { + if idx == noIndex { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: noIndex\n", idx, v, p) return } - if bit := v & (1 << (63 - p.len())); bit == 0 { - rIdx, gotIdx := p.right().cacheIndex() - if !gotIdx { - panic("BUG: no idx") + node := ft.np.node(idx) + switch { + case p.len() == ft.maxDepth: + if node.left != noIndex || node.right != noIndex { + panic("BUG: node @ maxDepth has children") } - rNode := &ft.nodes[rIdx] - r.update(rNode) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s + go left\n", v, p, - // rNode.c, rNode.fp) - ft.aggregateLeft(v, p.left(), r) - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) - ft.aggregateLeft(v, p.right(), r) - } - - // switch { - // case p.len() >= cachedBits: - // r.tails = append(r.tails, p) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: add tail\n", v, 
p) - // case bit == 0: - // idx, gotIdx := p.right().cacheIndex() - // if !gotIdx { - // panic("BUG: no idx") - // } - // r.update(&ft.nodes[idx]) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 0 -> add count %d fp %s\n", v, p, - // ft.nodes[idx].c, ft.nodes[idx].fp) - // ft.aggregateLeft(v, p.left(), r) - // default: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: %016x %s: 1 -> go right\n", v, p) - // ft.aggregateLeft(v, p.right(), r) - // } -} - -func (ft *fpTree) aggregateRight(v uint64, p prefix, r *aggResult) { - if p.len() >= cachedBits { - r.tails = append(r.tails, p.bits()<<(24-p.len())) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) - return - } - idx, gotIdx := p.cacheIndex() - if !gotIdx { - panic("BUG: no idx") - } - node := &ft.nodes[idx] - if node.empty() { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: idx=%d: %016x %s: empty node\n", idx, v, p) - return + tail := ft.tailRefFromPrefix(p) + r.tailRefs = append(r.tailRefs, tail) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: hit maxDepth, add prefix to the tails: %016x\n", idx, v, p, tail) + case node.leaf(): + // For leaf 1-nodes, we can use the fingerprint to get tailRef + // by which the actual IDs will be selected + if node.c != 1 { + panic("BUG: leaf non-1 node below maxDepth") + } + tail := ft.tailRefFromFingerprint(node.fp) + r.tailRefs = append(r.tailRefs, tail) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: hit 1-leaf, add prefix to the tails: %016x (fp %s)\n", idx, v, p, tail, node.fp) + case v&bit63 == 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: incl right node %d + go left to node %d\n", idx, v, p, node.right, node.left) + if node.right != noIndex { + r.update(ft.np.node(node.right)) + } + ft.aggregateLeft(node.left, v<<1, p.left(), r) + default: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: go right node %d\n", idx, v, p, node.right) + ft.aggregateLeft(node.right, 
v<<1, p.right(), r) } - if node.leaf() { - r.tails = append(r.tails, uint64(node.c & ^uint32(nodeFlagMask))) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: leaf\n", v, p) +} + +func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r *aggResult) { + if idx == noIndex { + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: noIndex\n", idx, v, p) return } - if bit := v & (1 << (63 - p.len())); bit == 0 { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> go left\n", v, p) - ft.aggregateRight(v, p.left(), r) - } else { - lIdx, gotIdx := p.left().cacheIndex() - if !gotIdx { - panic("BUG: no idx") + node := ft.np.node(idx) + switch { + case p.len() == ft.maxDepth: + if node.left != noIndex || node.right != noIndex { + panic("BUG: node @ maxDepth has children") + } + tail := ft.tailRefFromPrefix(p) + r.tailRefs = append(r.tailRefs, tail) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: hit maxDepth, add prefix to the tails: %016x\n", idx, v, p, tail) + case node.leaf(): + // For leaf 1-nodes, we can use the fingerprint to get tailRef + // by which the actual IDs will be selected + if node.c != 1 { + panic("BUG: leaf non-1 node below maxDepth") } - lNode := &ft.nodes[lIdx] - r.update(lNode) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> add count %d fp %s + go right\n", v, p, - // lNode.c, lNode.fp) - ft.aggregateRight(v, p.right(), r) - } - - // bit := v & (1 << (63 - p.len())) - // switch { - // case p.len() >= cachedBits: - // r.tails = append(r.tails, p) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: add tail\n", v, p) - // case bit == 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 0 -> go left\n", v, p) - // ft.aggregateRight(v, p.left(), r) - // default: - // idx, gotIdx := p.left().cacheIndex() - // if !gotIdx { - // panic("BUG: no idx") - // } - // r.update(&ft.nodes[idx]) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight: %016x %s: 1 -> add 
count %d fp %s + go right\n", v, p, - // ft.nodes[idx].c, ft.nodes[idx].fp) - // ft.aggregateRight(v, p.right(), r) - // } + tail := ft.tailRefFromFingerprint(node.fp) + r.tailRefs = append(r.tailRefs, tail) + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: hit 1-leaf, add prefix to the tails: %016x (fp %s)\n", idx, v, p, tail, node.fp) + case v&bit63 == 0: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: go left to node %d\n", idx, v, p, node.left) + ft.aggregateRight(node.left, v<<1, p.left(), r) + default: + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: incl left node %d + go right to node %d\n", idx, v, p, node.left, node.right) + if node.left != noIndex { + r.update(ft.np.node(node.left)) + } + ft.aggregateRight(node.right, v<<1, p.right(), r) + } } func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { @@ -382,62 +501,103 @@ func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { switch { case r.itype == 0: // the whole set - r.update(&ft.nodes[0]) + if ft.root != noIndex { + r.update(ft.np.node(ft.root)) + } case r.itype < 0: // "proper" interval: [x; lca); (lca; y) p := commonPrefix(x, y) - ft.aggregateLeft(load64(x), p.left(), &r) - ft.aggregateRight(load64(y), p.right(), &r) + lca, found := ft.followPrefix(ft.root, p) + // fmt.Fprintf(os.Stderr, "QQQQQ: commonPrefix %s lca %d found %v\n", p, lca, found) + switch { + case found: + lcaNode := ft.np.node(lca) + ft.aggregateLeft(lcaNode.left, load64(x)<<(p.len()+1), p.left(), &r) + ft.aggregateRight(lcaNode.right, load64(y)<<(p.len()+1), p.right(), &r) + case lca != noIndex: + // fmt.Fprintf(os.Stderr, "QQQQQ: commonPrefix %s NOT found but have lca %d\n", p, lca) + // Didn't reach LCA in the tree b/c ended up + // at a leaf, just use the prefix to go + // through the IDs + lcaNode := ft.np.node(lca) + r.tailRefs = append(r.tailRefs, ft.tailRefFromNodeAndPrefix(lcaNode, p)) + } default: // inverse interval: [min; y); [x; max] - 
ft.aggregateRight(load64(y), preFirst1(y), &r) - ft.aggregateLeft(load64(x), preFirst0(x), &r) + pf1 := preFirst1(y) + idx1, found := ft.followPrefix(ft.root, pf1) + switch { + case found: + ft.aggregateRight(idx1, load64(y)<> (64 - cachedBits) + idx := load64(h) >> (64 - mft.tree.maxDepth) s := mft.ids[idx] n := slices.IndexFunc(s, func(cur []byte) bool { return bytes.Compare(cur, h) > 0 @@ -451,7 +611,7 @@ func (mft *inMemFPTree) addHash(h []byte) { func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { r := mft.tree.aggregateInterval(x, y) - for _, t := range r.tails { + for _, t := range r.tailRefs { ids := mft.ids[t] for _, id := range ids { // FIXME: this can be optimized as the IDs are ordered diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go index 85408aa595..c6166ffe78 100644 --- a/sync2/dbsync/dbsync_test.go +++ b/sync2/dbsync/dbsync_test.go @@ -3,118 +3,153 @@ package dbsync import ( "bytes" "math/bits" + "math/rand" "slices" "strings" "testing" "github.com/stretchr/testify/require" "golang.org/x/exp/maps" - // "golang.org/x/exp/rand" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) +func TestNodePool(t *testing.T) { + // TODO: convert to TestRCPool + var np nodePool + idx1 := np.add(fingerprint{1, 2, 3}, 1, noIndex, noIndex) + node1 := np.node(idx1) + idx2 := np.add(fingerprint{2, 3, 4}, 2, idx1, noIndex) + node2 := np.node(idx2) + require.Equal(t, fingerprint{1, 2, 3}, node1.fp) + require.Equal(t, uint32(1), node1.c) + require.Equal(t, noIndex, node1.left) + require.Equal(t, noIndex, node1.right) + require.Equal(t, fingerprint{2, 3, 4}, node2.fp) + require.Equal(t, uint32(2), node2.c) + require.Equal(t, idx1, node2.left) + require.Equal(t, noIndex, node2.right) + idx3 := np.add(fingerprint{2, 3, 5}, 1, noIndex, noIndex) + idx4 := np.add(fingerprint{2, 3, 6}, 1, idx2, idx3) + require.Equal(t, nodeIndex(3), idx4) + np.ref(idx4) 
+ + np.release(idx4) + // not yet released due to an extra ref + require.Equal(t, nodeIndex(4), np.add(fingerprint{2, 3, 7}, 1, noIndex, noIndex)) + + np.release(idx4) + // idx4 was freed + require.Equal(t, idx4, np.add(fingerprint{2, 3, 8}, 1, noIndex, noIndex)) + + // free item used just once + require.Equal(t, nodeIndex(5), np.add(fingerprint{2, 3, 9}, 1, noIndex, noIndex)) + + // form a free list + np.release(idx3) + np.release(idx2) + np.release(idx1) + + // the free list is LIFO + require.Equal(t, idx1, np.add(fingerprint{2, 3, 10}, 1, noIndex, noIndex)) + require.Equal(t, idx2, np.add(fingerprint{2, 3, 11}, 1, noIndex, noIndex)) + require.Equal(t, idx3, np.add(fingerprint{2, 3, 12}, 1, noIndex, noIndex)) + + // the free list is exhausted + require.Equal(t, nodeIndex(6), np.add(fingerprint{2, 3, 13}, 1, noIndex, noIndex)) +} + func TestPrefix(t *testing.T) { for _, tc := range []struct { - p prefix - s string - bits uint64 - len int - left prefix - right prefix - gotCacheIndex bool - cacheIndex cacheIndex + p prefix + s string + bits uint64 + len int + left prefix + right prefix + shift prefix }{ { - p: 0, - s: "<0>", - len: 0, - bits: 0, - left: 0b0_000001, - right: 0b1_000001, - gotCacheIndex: true, - cacheIndex: 0, + p: 0, + s: "<0>", + len: 0, + bits: 0, + left: 0b0_000001, + right: 0b1_000001, }, { - p: 0b0_000001, - s: "<1:0>", - len: 1, - bits: 0, - left: 0b00_000010, - right: 0b01_000010, - gotCacheIndex: true, - cacheIndex: 1, + p: 0b0_000001, + s: "<1:0>", + len: 1, + bits: 0, + left: 0b00_000010, + right: 0b01_000010, + shift: 0, }, { - p: 0b1_000001, - s: "<1:1>", - len: 1, - bits: 1, - left: 0b10_000010, - right: 0b11_000010, - gotCacheIndex: true, - cacheIndex: 2, + p: 0b1_000001, + s: "<1:1>", + len: 1, + bits: 1, + left: 0b10_000010, + right: 0b11_000010, + shift: 0, }, { - p: 0b00_000010, - s: "<2:00>", - len: 2, - bits: 0, - left: 0b000_000011, - right: 0b001_000011, - gotCacheIndex: true, - cacheIndex: 3, + p: 0b00_000010, + s: "<2:00>", + 
len: 2, + bits: 0, + left: 0b000_000011, + right: 0b001_000011, + shift: 0b0_000001, }, { - p: 0b01_000010, - s: "<2:01>", - len: 2, - bits: 1, - left: 0b010_000011, - right: 0b011_000011, - gotCacheIndex: true, - cacheIndex: 4, + p: 0b01_000010, + s: "<2:01>", + len: 2, + bits: 1, + left: 0b010_000011, + right: 0b011_000011, + shift: 0b1_000001, }, { - p: 0b10_000010, - s: "<2:10>", - len: 2, - bits: 2, - left: 0b100_000011, - right: 0b101_000011, - gotCacheIndex: true, - cacheIndex: 5, + p: 0b10_000010, + s: "<2:10>", + len: 2, + bits: 2, + left: 0b100_000011, + right: 0b101_000011, + shift: 0b0_000001, }, { - p: 0b11_000010, - s: "<2:11>", - len: 2, - bits: 3, - left: 0b110_000011, - right: 0b111_000011, - gotCacheIndex: true, - cacheIndex: 6, + p: 0b11_000010, + s: "<2:11>", + len: 2, + bits: 3, + left: 0b110_000011, + right: 0b111_000011, + shift: 0b1_000001, }, { - p: 0x3fffffd8, - s: "<24:111111111111111111111111>", - len: 24, - bits: 0xffffff, - left: 0x7fffff99, - right: 0x7fffffd9, - gotCacheIndex: true, - cacheIndex: 0x1fffffe, + p: 0x3fffffd8, + s: "<24:111111111111111111111111>", + len: 24, + bits: 0xffffff, + left: 0x7fffff99, + right: 0x7fffffd9, + shift: 0x1fffffd7, }, { - p: 0x7fffff99, - s: "<25:1111111111111111111111110>", - len: 25, - bits: 0x1fffffe, - left: 0xffffff1a, - right: 0xffffff5a, - gotCacheIndex: false, // len > 24 + p: 0x7fffff99, + s: "<25:1111111111111111111111110>", + len: 25, + bits: 0x1fffffe, + left: 0xffffff1a, + right: 0xffffff5a, + shift: 0x3fffff98, }, } { require.Equal(t, tc.s, tc.p.String()) @@ -122,85 +157,83 @@ func TestPrefix(t *testing.T) { require.Equal(t, tc.len, tc.p.len()) require.Equal(t, tc.left, tc.p.left()) require.Equal(t, tc.right, tc.p.right()) - idx, gotIdx := tc.p.cacheIndex() - require.Equal(t, tc.gotCacheIndex, gotIdx) - if gotIdx { - require.Equal(t, tc.cacheIndex, idx) + if tc.p != 0 { + require.Equal(t, tc.shift, tc.p.shift()) } } } -func TestHashPrefix(t *testing.T) { - for _, tc := range []struct 
{ - h string - l int - p prefix - preFirst0 prefix - preFirst1 prefix - }{ - { - h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", - l: 0, - p: 0, - preFirst0: 0b1_000001, - preFirst1: 0, - }, - { - h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", - l: 1, - p: 0b1_000001, - preFirst0: 0b1_000001, - preFirst1: 0, - }, - { - h: "2BCDEF1234567890000000000000000000000000000000000000000000000000", - l: 1, - p: 0b0_000001, - preFirst0: 0, - preFirst1: 0b00_000010, - }, - { - h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", - l: 4, - p: 0b1010_000100, - preFirst0: 0b1_000001, - preFirst1: 0, - }, - { - h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", - l: 57, - p: 0x55e6f7891a2b3c79, - preFirst0: 0b1_000001, - preFirst1: 0, - }, - { - h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", - l: 58, - p: 0xabcdef12345678ba, - preFirst0: 0b1_000001, - preFirst1: 0, - }, - { - h: "0000000000000000000000000000000000000000000000000000000000000000", - l: 0, - p: 0, - preFirst0: 0, - preFirst1: 58, - }, - { - h: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", - l: 0, - p: 0, - preFirst0: 0xfffffffffffffffa, - preFirst1: 0, - }, - } { - h := types.HexToHash32(tc.h) - require.Equal(t, tc.p, hashPrefix(h[:], tc.l), "hash prefix: h %s l %d", tc.h, tc.l) - require.Equal(t, tc.preFirst0, preFirst0(h[:]), "preFirst0: h %s", tc.h) - require.Equal(t, tc.preFirst1, preFirst1(h[:]), "preFirst1: h %s", tc.h) - } -} +// func TestHashPrefix(t *testing.T) { +// for _, tc := range []struct { +// h string +// l int +// p prefix +// preFirst0 prefix +// preFirst1 prefix +// }{ +// { +// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", +// l: 0, +// p: 0, +// preFirst0: 0b1_000001, +// preFirst1: 0, +// }, +// { +// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", +// l: 1, +// p: 0b1_000001, +// preFirst0: 0b1_000001, +// preFirst1: 
0, +// }, +// { +// h: "2BCDEF1234567890000000000000000000000000000000000000000000000000", +// l: 1, +// p: 0b0_000001, +// preFirst0: 0, +// preFirst1: 0b00_000010, +// }, +// { +// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", +// l: 4, +// p: 0b1010_000100, +// preFirst0: 0b1_000001, +// preFirst1: 0, +// }, +// { +// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", +// l: 57, +// p: 0x55e6f7891a2b3c79, +// preFirst0: 0b1_000001, +// preFirst1: 0, +// }, +// { +// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", +// l: 58, +// p: 0xabcdef12345678ba, +// preFirst0: 0b1_000001, +// preFirst1: 0, +// }, +// { +// h: "0000000000000000000000000000000000000000000000000000000000000000", +// l: 0, +// p: 0, +// preFirst0: 0, +// preFirst1: 58, +// }, +// { +// h: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", +// l: 0, +// p: 0, +// preFirst0: 0xfffffffffffffffa, +// preFirst1: 0, +// }, +// } { +// h := types.HexToHash32(tc.h) +// require.Equal(t, tc.p, hashPrefix(h[:], tc.l), "hash prefix: h %s l %d", tc.h, tc.l) +// require.Equal(t, tc.preFirst0, preFirst0(h[:]), "preFirst0: h %s", tc.h) +// require.Equal(t, tc.preFirst1, preFirst1(h[:]), "preFirst1: h %s", tc.h) +// } +// } func TestCommonPrefix(t *testing.T) { for _, tc := range []struct { @@ -282,65 +315,83 @@ func TestRmme(t *testing.T) { } func TestInMemFPTree(t *testing.T) { - var mft inMemFPTree - var hs []types.Hash32 - for _, hex := range []string{ - "0000000000000000000000000000000000000000000000000000000000000000", - "123456789ABCDEF0000000000000000000000000000000000000000000000000", - "5555555555555555555555555555555555555555555555555555555555555555", - "8888888888888888888888888888888888888888888888888888888888888888", - "ABCDEF1234567890000000000000000000000000000000000000000000000000", + for _, tc := range []struct { + name string + ids []string + results map[[2]int]fpResult + }{ + { + name: "ids1", + ids: 
[]string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "123456789ABCDEF0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "ABCDEF1234567890000000000000000000000000000000000000000000000000", + }, + results: map[[2]int]fpResult{ + {0, 0}: { + fp: hexToFingerprint("642464b773377bbddddddddd"), + count: 5, + }, + {4, 4}: { + fp: hexToFingerprint("642464b773377bbddddddddd"), + count: 5, + }, + {0, 1}: { + fp: hexToFingerprint("000000000000000000000000"), + count: 1, + }, + {1, 4}: { + fp: hexToFingerprint("cfe98ba54761032ddddddddd"), + count: 3, + }, + }, + }, + { + name: "ids2", + ids: []string{ + "829977b444c8408dcddc1210536f3b3bdc7fd97777426264b9ac8f70b97a7fd1", + "6e476ca729c3840d0118785496e488124ee7dade1aef0c87c6edc78f72e4904f", + "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", + "e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", + }, + results: map[[2]int]fpResult{ + {0, 0}: { + fp: hexToFingerprint("a76fc452775b55e0dacd8be5"), + count: 4, + }, + {0, 3}: { + fp: hexToFingerprint("2019cb0c56fbd36d197d4c4c"), + count: 2, + }, + }, + }, } { - t.Logf("QQQQQ: ADD: %s", hex) - h := types.HexToHash32(hex) - hs = append(hs, h) - mft.addHash(h[:]) - } - var sb strings.Builder - mft.tree.dump(&sb) - t.Logf("QQQQQ: tree:\n%s", sb.String()) - require.Equal(t, hexToFingerprint("642464b773377bbddddddddd"), mft.tree.nodes[0].fp) - require.Equal(t, fpResult{ - fp: hexToFingerprint("642464b773377bbddddddddd"), - count: 5, - }, mft.aggregateInterval(hs[0][:], hs[0][:])) - require.Equal(t, fpResult{ - fp: hexToFingerprint("642464b773377bbddddddddd"), - count: 5, - }, mft.aggregateInterval(hs[4][:], hs[4][:])) - require.Equal(t, fpResult{ - fp: hexToFingerprint("000000000000000000000000"), - count: 1, - }, mft.aggregateInterval(hs[0][:], hs[1][:])) - require.Equal(t, fpResult{ - 
fp: hexToFingerprint("cfe98ba54761032ddddddddd"), - count: 3, - }, mft.aggregateInterval(hs[1][:], hs[4][:])) - // TBD: test reverse range -} + t.Run(tc.name, func(t *testing.T) { + var np nodePool + mft := newInMemFPTree(&np, 24) + var hs []types.Hash32 + for _, hex := range tc.ids { + t.Logf("QQQQQ: ADD: %s", hex) + h := types.HexToHash32(hex) + hs = append(hs, h) + mft.addHash(h[:]) + } -func TestInMemFPTreeRmme1(t *testing.T) { - var mft inMemFPTree - var hs []types.Hash32 - for _, hex := range []string{ - "829977b444c8408dcddc1210536f3b3bdc7fd97777426264b9ac8f70b97a7fd1", - "6e476ca729c3840d0118785496e488124ee7dade1aef0c87c6edc78f72e4904f", - "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", - "e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", - } { - t.Logf("QQQQQ: ADD: %s", hex) - h := types.HexToHash32(hex) - hs = append(hs, h) - mft.addHash(h[:]) + var sb strings.Builder + mft.tree.dump(&sb) + t.Logf("tree:\n%s", sb.String()) + + checkTree(t, mft.tree, 24) + + for idRange, fpResult := range tc.results { + x := hs[idRange[0]] + y := hs[idRange[1]] + require.Equal(t, fpResult, mft.aggregateInterval(x[:], y[:])) + } + }) } - var sb strings.Builder - mft.tree.dump(&sb) - t.Logf("QQQQQ: tree:\n%s", sb.String()) - require.Equal(t, hexToFingerprint("a76fc452775b55e0dacd8be5"), mft.tree.nodes[0].fp) - require.Equal(t, fpResult{ - fp: hexToFingerprint("2019cb0c56fbd36d197d4c4c"), - count: 2, - }, mft.aggregateInterval(hs[0][:], hs[3][:])) } type hashList []types.Hash32 @@ -352,16 +403,56 @@ func (l hashList) findGTE(h types.Hash32) int { return p } -func TestInMemFPTreeManyItems(t *testing.T) { - var mft inMemFPTree - const numItems = 1 << 20 +func checkNode(t *testing.T, ft *fpTree, idx nodeIndex, depth int) { + node := ft.np.node(idx) + if node.left == noIndex && node.right == noIndex { + if node.c != 1 { + require.Equal(t, depth, ft.maxDepth) + } + } else { + require.Less(t, depth, ft.maxDepth) + var expFP fingerprint + var 
expCount uint32 + if node.left != noIndex { + checkNode(t, ft, node.left, depth+1) + left := ft.np.node(node.left) + expFP.update(left.fp[:]) + expCount += left.c + } + if node.right != noIndex { + checkNode(t, ft, node.right, depth+1) + right := ft.np.node(node.right) + expFP.update(right.fp[:]) + expCount += right.c + } + require.Equal(t, expFP, node.fp, "node fp at depth %d", depth) + require.Equal(t, expCount, node.c, "node count at depth %d", depth) + } +} + +func checkTree(t *testing.T, ft *fpTree, maxDepth int) { + require.Equal(t, maxDepth, ft.maxDepth) + checkNode(t, ft, ft.root, 0) + +} + +func testInMemFPTreeManyItems(t *testing.T, randomXY bool) { + var np nodePool + const ( + numItems = 1 << 16 + maxDepth = 24 + ) + mft := newInMemFPTree(&np, maxDepth) hs := make(hashList, numItems) var fp fingerprint + rmmeMap := make(map[types.Hash32]bool) for i := range hs { h := types.RandomHash() hs[i] = h mft.addHash(h[:]) fp.update(h[:]) + require.False(t, rmmeMap[h]) + rmmeMap[h] = true } // var sb strings.Builder // mft.tree.dump(&sb) @@ -372,14 +463,28 @@ func TestInMemFPTreeManyItems(t *testing.T) { // for i, h := range hs { // t.Logf("h[%d] = %s", i, h.String()) // } - require.Equal(t, fp, mft.tree.nodes[0].fp) + + total := 0 + nums := make(map[int]int) + for _, ids := range mft.ids { + nums[len(ids)]++ + total += len(ids) + } + t.Logf("total %d, numItems %d, nums %#v", total, numItems, nums) + + checkTree(t, mft.tree, maxDepth) + + require.Equal(t, fpResult{fp: fp, count: numItems}, mft.aggregateInterval(hs[0][:], hs[0][:])) for i := 0; i < 100; i++ { // TBD: allow reverse order - // TBD: pick some intervals from the hashes - x := types.RandomHash() - y := types.RandomHash() - // x := hs[rand.Intn(numItems)] - // y := hs[rand.Intn(numItems)] + var x, y types.Hash32 + if randomXY { + x = types.RandomHash() + y = types.RandomHash() + } else { + x = hs[rand.Intn(numItems)] + y = hs[rand.Intn(numItems)] + } c := bytes.Compare(x[:], y[:]) var ( expFP 
fingerprint @@ -406,4 +511,14 @@ func TestInMemFPTreeManyItems(t *testing.T) { count: expN, }, mft.aggregateInterval(x[:], y[:])) } + // TODO: test inverse intervals +} + +func TestInMemFPTreeManyItems(t *testing.T) { + t.Run("bounds from the set", func(t *testing.T) { + testInMemFPTreeManyItems(t, false) + }) + t.Run("random bounds", func(t *testing.T) { + testInMemFPTreeManyItems(t, true) + }) } diff --git a/sync2/dbsync/refcountpool.go b/sync2/dbsync/refcountpool.go new file mode 100644 index 0000000000..85f508a03a --- /dev/null +++ b/sync2/dbsync/refcountpool.go @@ -0,0 +1,75 @@ +package dbsync + +import "sync" + +const freeBit = 1 << 31 +const freeListMask = freeBit - 1 + +type poolEntry[T any, I ~uint32] struct { + refCount uint32 + content T +} + +type rcPool[T any, I ~uint32] struct { + mtx sync.Mutex + entries []poolEntry[T, I] + // freeList is 1-based so that rcPool doesn't need a constructor + freeList uint32 +} + +func (rc *rcPool[T, I]) item(idx I) T { + rc.mtx.Lock() + defer rc.mtx.Unlock() + return rc.entry(idx).content +} + +func (rc *rcPool[T, I]) entry(idx I) *poolEntry[T, I] { + entry := &rc.entries[idx] + if entry.refCount&freeBit != 0 { + panic("BUG: referencing a free nodePool entry") + } + return entry +} + +func (rc *rcPool[T, I]) add(item T) I { + rc.mtx.Lock() + defer rc.mtx.Unlock() + var idx I + // // validate indices + // if left != I(^uint32(0)) { + // rc.entry(left) + // } + // if right != I(^uint32(0)) { + // rc.entry(right) + // } + if rc.freeList != 0 { + idx = I(rc.freeList - 1) + rc.freeList = rc.entries[idx].refCount & freeListMask + rc.entries[idx].refCount = 1 + } else { + idx = I(len(rc.entries)) + rc.entries = append(rc.entries, poolEntry[T, I]{refCount: 1}) + } + rc.entries[idx].content = item + return idx +} + +func (rc *rcPool[T, I]) release(idx I) { + rc.mtx.Lock() + defer rc.mtx.Unlock() + entry := &rc.entries[idx] + if entry.refCount <= 0 { + panic("BUG: negative rcPool[T, I] entry refcount") + } + entry.refCount-- + 
if entry.refCount == 0 { + entry.refCount = rc.freeList | freeBit + rc.freeList = uint32(idx + 1) + } +} + +func (rc *rcPool[T, I]) ref(idx nodeIndex) { + rc.mtx.Lock() + rc.entries[idx].refCount++ + rc.mtx.Unlock() +} From df3e554a022a86f946875d4191a7c19a27e24b21 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 19 Jun 2024 19:13:43 +0400 Subject: [PATCH 33/76] test inverse intervals --- sync2/dbsync/dbsync.go | 3 +++ sync2/dbsync/dbsync_test.go | 50 ++++++++++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/sync2/dbsync/dbsync.go b/sync2/dbsync/dbsync.go index 4a9755b9dd..56813569f8 100644 --- a/sync2/dbsync/dbsync.go +++ b/sync2/dbsync/dbsync.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "math/bits" + // "os" "slices" "strconv" ) @@ -526,6 +527,7 @@ func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { // inverse interval: [min; y); [x; max] pf1 := preFirst1(y) idx1, found := ft.followPrefix(ft.root, pf1) + // fmt.Fprintf(os.Stderr, "QQQQQ: pf1 %s idx1 %d found %v\n", pf1, idx1, found) switch { case found: ft.aggregateRight(idx1, load64(y)< 0 { - x, y = y, x - } - if c == 0 { - expFP = fp - expN = numItems - } else { + switch bytes.Compare(x[:], y[:]) { + case -1: pX := hs.findGTE(x) pY := hs.findGTE(y) // t.Logf("x=%s y=%s pX=%d y=%d", x.String(), y.String(), pX, pY) @@ -505,13 +519,25 @@ func testInMemFPTreeManyItems(t *testing.T, randomXY bool) { expFP.update(hs[p][:]) } expN = uint32(pY - pX) + case 1: + pX := hs.findGTE(x) + pY := hs.findGTE(y) + for p := 0; p < pY; p++ { + expFP.update(hs[p][:]) + } + for p := pX; p < len(hs); p++ { + expFP.update(hs[p][:]) + } + expN = uint32(pY + len(hs) - pX) + default: + expFP = fp + expN = numItems } require.Equal(t, fpResult{ fp: expFP, count: expN, }, mft.aggregateInterval(x[:], y[:])) } - // TODO: test inverse intervals } func TestInMemFPTreeManyItems(t *testing.T) { From 949f2071b94d7394be5e22a5086ce65fefeb17f6 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: 
Wed, 19 Jun 2024 21:33:38 +0400 Subject: [PATCH 34/76] cleanup --- sync2/dbsync/dbsync_test.go | 78 ------------------------------------ sync2/dbsync/refcountpool.go | 9 +---- 2 files changed, 2 insertions(+), 85 deletions(-) diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go index 1862745385..cae90c9bc8 100644 --- a/sync2/dbsync/dbsync_test.go +++ b/sync2/dbsync/dbsync_test.go @@ -163,78 +163,6 @@ func TestPrefix(t *testing.T) { } } -// func TestHashPrefix(t *testing.T) { -// for _, tc := range []struct { -// h string -// l int -// p prefix -// preFirst0 prefix -// preFirst1 prefix -// }{ -// { -// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", -// l: 0, -// p: 0, -// preFirst0: 0b1_000001, -// preFirst1: 0, -// }, -// { -// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", -// l: 1, -// p: 0b1_000001, -// preFirst0: 0b1_000001, -// preFirst1: 0, -// }, -// { -// h: "2BCDEF1234567890000000000000000000000000000000000000000000000000", -// l: 1, -// p: 0b0_000001, -// preFirst0: 0, -// preFirst1: 0b00_000010, -// }, -// { -// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", -// l: 4, -// p: 0b1010_000100, -// preFirst0: 0b1_000001, -// preFirst1: 0, -// }, -// { -// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", -// l: 57, -// p: 0x55e6f7891a2b3c79, -// preFirst0: 0b1_000001, -// preFirst1: 0, -// }, -// { -// h: "ABCDEF1234567890000000000000000000000000000000000000000000000000", -// l: 58, -// p: 0xabcdef12345678ba, -// preFirst0: 0b1_000001, -// preFirst1: 0, -// }, -// { -// h: "0000000000000000000000000000000000000000000000000000000000000000", -// l: 0, -// p: 0, -// preFirst0: 0, -// preFirst1: 58, -// }, -// { -// h: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", -// l: 0, -// p: 0, -// preFirst0: 0xfffffffffffffffa, -// preFirst1: 0, -// }, -// } { -// h := types.HexToHash32(tc.h) -// require.Equal(t, tc.p, hashPrefix(h[:], tc.l), 
"hash prefix: h %s l %d", tc.h, tc.l) -// require.Equal(t, tc.preFirst0, preFirst0(h[:]), "preFirst0: h %s", tc.h) -// require.Equal(t, tc.preFirst1, preFirst1(h[:]), "preFirst1: h %s", tc.h) -// } -// } - func TestCommonPrefix(t *testing.T) { for _, tc := range []struct { a, b string @@ -474,15 +402,9 @@ func testInMemFPTreeManyItems(t *testing.T, randomXY bool) { require.False(t, rmmeMap[h]) rmmeMap[h] = true } - // var sb strings.Builder - // mft.tree.dump(&sb) - // t.Logf("QQQQQ: tree:\n%s", sb.String()) slices.SortFunc(hs, func(a, b types.Hash32) int { return a.Compare(b) }) - // for i, h := range hs { - // t.Logf("h[%d] = %s", i, h.String()) - // } total := 0 nums := make(map[int]int) diff --git a/sync2/dbsync/refcountpool.go b/sync2/dbsync/refcountpool.go index 85f508a03a..5e357aa819 100644 --- a/sync2/dbsync/refcountpool.go +++ b/sync2/dbsync/refcountpool.go @@ -35,13 +35,6 @@ func (rc *rcPool[T, I]) add(item T) I { rc.mtx.Lock() defer rc.mtx.Unlock() var idx I - // // validate indices - // if left != I(^uint32(0)) { - // rc.entry(left) - // } - // if right != I(^uint32(0)) { - // rc.entry(right) - // } if rc.freeList != 0 { idx = I(rc.freeList - 1) rc.freeList = rc.entries[idx].refCount & freeListMask @@ -73,3 +66,5 @@ func (rc *rcPool[T, I]) ref(idx nodeIndex) { rc.entries[idx].refCount++ rc.mtx.Unlock() } + +// TODO: convert TestNodePool to TestRCPool From 3a8b8195a9fbbc6e28b9a3a5769342a3e345488a Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 20 Jun 2024 05:09:08 +0400 Subject: [PATCH 35/76] separate test for rcpool --- sync2/dbsync/dbsync_test.go | 45 ----------------------------- sync2/dbsync/refcountpool.go | 2 +- sync2/dbsync/refcountpool_test.go | 48 +++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 46 deletions(-) create mode 100644 sync2/dbsync/refcountpool_test.go diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go index cae90c9bc8..36889ad329 100644 --- a/sync2/dbsync/dbsync_test.go +++ 
b/sync2/dbsync/dbsync_test.go @@ -16,51 +16,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/statesql" ) -func TestNodePool(t *testing.T) { - // TODO: convert to TestRCPool - var np nodePool - idx1 := np.add(fingerprint{1, 2, 3}, 1, noIndex, noIndex) - node1 := np.node(idx1) - idx2 := np.add(fingerprint{2, 3, 4}, 2, idx1, noIndex) - node2 := np.node(idx2) - require.Equal(t, fingerprint{1, 2, 3}, node1.fp) - require.Equal(t, uint32(1), node1.c) - require.Equal(t, noIndex, node1.left) - require.Equal(t, noIndex, node1.right) - require.Equal(t, fingerprint{2, 3, 4}, node2.fp) - require.Equal(t, uint32(2), node2.c) - require.Equal(t, idx1, node2.left) - require.Equal(t, noIndex, node2.right) - idx3 := np.add(fingerprint{2, 3, 5}, 1, noIndex, noIndex) - idx4 := np.add(fingerprint{2, 3, 6}, 1, idx2, idx3) - require.Equal(t, nodeIndex(3), idx4) - np.ref(idx4) - - np.release(idx4) - // not yet released due to an extra ref - require.Equal(t, nodeIndex(4), np.add(fingerprint{2, 3, 7}, 1, noIndex, noIndex)) - - np.release(idx4) - // idx4 was freed - require.Equal(t, idx4, np.add(fingerprint{2, 3, 8}, 1, noIndex, noIndex)) - - // free item used just once - require.Equal(t, nodeIndex(5), np.add(fingerprint{2, 3, 9}, 1, noIndex, noIndex)) - - // form a free list - np.release(idx3) - np.release(idx2) - np.release(idx1) - - // the free list is LIFO - require.Equal(t, idx1, np.add(fingerprint{2, 3, 10}, 1, noIndex, noIndex)) - require.Equal(t, idx2, np.add(fingerprint{2, 3, 11}, 1, noIndex, noIndex)) - require.Equal(t, idx3, np.add(fingerprint{2, 3, 12}, 1, noIndex, noIndex)) - - // the free list is exhausted - require.Equal(t, nodeIndex(6), np.add(fingerprint{2, 3, 13}, 1, noIndex, noIndex)) -} - func TestPrefix(t *testing.T) { for _, tc := range []struct { p prefix diff --git a/sync2/dbsync/refcountpool.go b/sync2/dbsync/refcountpool.go index 5e357aa819..211831af83 100644 --- a/sync2/dbsync/refcountpool.go +++ b/sync2/dbsync/refcountpool.go @@ -61,7 +61,7 @@ func (rc 
*rcPool[T, I]) release(idx I) { } } -func (rc *rcPool[T, I]) ref(idx nodeIndex) { +func (rc *rcPool[T, I]) ref(idx I) { rc.mtx.Lock() rc.entries[idx].refCount++ rc.mtx.Unlock() diff --git a/sync2/dbsync/refcountpool_test.go b/sync2/dbsync/refcountpool_test.go new file mode 100644 index 0000000000..de637c61d7 --- /dev/null +++ b/sync2/dbsync/refcountpool_test.go @@ -0,0 +1,48 @@ +package dbsync + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRCPool(t *testing.T) { + type foo struct{ x int } + type fooIndex uint32 + // TODO: convert to TestRCPool + var pool rcPool[foo, fooIndex] + idx1 := pool.add(foo{x: 1}) + foo1 := pool.item(idx1) + idx2 := pool.add(foo{x: 2}) + foo2 := pool.item(idx2) + require.Equal(t, foo{x: 1}, foo1) + require.Equal(t, foo{x: 2}, foo2) + idx3 := pool.add(foo{x: 3}) + idx4 := pool.add(foo{x: 4}) + require.Equal(t, fooIndex(3), idx4) + pool.ref(idx4) + + pool.release(idx4) + // not yet released due to an extra ref + require.Equal(t, fooIndex(4), pool.add(foo{x: 5})) + + pool.release(idx4) + // idx4 was freed + require.Equal(t, idx4, pool.add(foo{x: 6})) + + // free item used just once + require.Equal(t, fooIndex(5), pool.add(foo{x: 7})) + + // form a free list containing several items + pool.release(idx3) + pool.release(idx2) + pool.release(idx1) + + // the free list is LIFO + require.Equal(t, idx1, pool.add(foo{x: 8})) + require.Equal(t, idx2, pool.add(foo{x: 9})) + require.Equal(t, idx3, pool.add(foo{x: 10})) + + // the free list is exhausted + require.Equal(t, fooIndex(6), pool.add(foo{x: 11})) +} From 2cccc052748c09e9c1ac9b688edc1fb34dd1429b Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 20 Jun 2024 12:16:35 +0400 Subject: [PATCH 36/76] sql database --- sync2/dbsync/dbsync.go | 220 +++++++++++++++----- sync2/dbsync/dbsync_test.go | 389 +++++++++++++++++++++++++++++------- 2 files changed, 484 insertions(+), 125 deletions(-) diff --git a/sync2/dbsync/dbsync.go b/sync2/dbsync/dbsync.go index 
56813569f8..04954957bf 100644 --- a/sync2/dbsync/dbsync.go +++ b/sync2/dbsync/dbsync.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "math/bits" - // "os" "slices" "strconv" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" ) const ( @@ -203,6 +205,10 @@ const ( type prefix uint64 +func mkprefix(bits uint64, l int) prefix { + return prefix(bits<>(p.len()-1) != 0 } +func (p prefix) minID(b []byte) { + if len(b) < 8 { + panic("BUG: id slice too small") + } + v := p.bits() << (64 - p.len()) + binary.BigEndian.PutUint64(b, v) + for n := 8; n < len(b); n++ { + b[n] = 0 + } +} + +func (p prefix) maxID(b []byte) { + if len(b) < 8 { + panic("BUG: id slice too small") + } + s := uint64(64 - p.len()) + v := (p.bits() << s) | ((1 << s) - 1) + binary.BigEndian.PutUint64(b, v) + for n := 8; n < len(b); n++ { + b[n] = 0xff + } +} + // shift removes the highest bit from the prefix -// TBD: QQQQQ: test shift func (p prefix) shift() prefix { - switch l := uint64(p.len()); l { + switch l := p.len(); l { case 0: panic("BUG: can't shift zero prefix") case 1: return 0 default: - return prefix(((p.bits() & ((1 << (l - 1)) - 1)) << prefixLenBits) + l - 1) + l-- + return mkprefix(p.bits()&((1<>(64-l))<>(64-l), l) +} + +type fpResult struct { + fp fingerprint + count uint32 } type aggResult struct { @@ -302,14 +336,20 @@ func (r *aggResult) update(node node) { // fmt.Fprintf(os.Stderr, "QQQQQ: r.count <= %d r.fp <= %s\n", r.count, r.fp) } +type idStore interface { + registerHash(h []byte, maxDepth int) error + iterateIDs(tailRefs []uint64, maxDepth int, toCall func(id []byte)) error +} + type fpTree struct { np *nodePool + idStore idStore root nodeIndex maxDepth int } -func newFPTree(np *nodePool, maxDepth int) *fpTree { - return &fpTree{np: np, root: noIndex, maxDepth: maxDepth} +func newFPTree(np *nodePool, idStore idStore, maxDepth int) *fpTree { + return &fpTree{np: np, idStore: idStore, root: noIndex, maxDepth: maxDepth} } func (ft *fpTree) 
pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { @@ -378,10 +418,11 @@ func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { } } -func (ft *fpTree) addHash(h []byte) { +func (ft *fpTree) addHash(h []byte) error { var fp fingerprint fp.update(h) ft.root = ft.addValue(fp, 0, ft.root) + return ft.idStore.registerHash(h, ft.maxDepth) // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: new root %d\n", ft.root) } @@ -402,9 +443,11 @@ func (ft *fpTree) followPrefix(from nodeIndex, p prefix) (nodeIndex, bool) { } func (ft *fpTree) tailRefFromPrefix(p prefix) uint64 { - if p.len() != ft.maxDepth { - panic("BUG: tail from short prefix") - } + // TODO: QQQQ: FIXME: this may happen with reverse intervals, + // but should we even be checking the prefixes in this case? + // if p.len() != ft.maxDepth { + // panic("BUG: tail from short prefix") + // } return p.bits() } @@ -550,6 +593,19 @@ func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { return r } +func (ft *fpTree) fingerprintInterval(x, y []byte) (fpResult, error) { + r := ft.aggregateInterval(x, y) + if err := ft.idStore.iterateIDs(r.tailRefs, ft.maxDepth, func(id []byte) { + if idWithinInterval(id, x, y, r.itype) { + r.fp.update(id) + r.count++ + } + }); err != nil { + return fpResult{}, err + } + return fpResult{fp: r.fp, count: r.count}, nil +} + func (ft *fpTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) { if idx == noIndex { return @@ -578,58 +634,126 @@ func (ft *fpTree) dump(w io.Writer) { } } -type inMemFPTree struct { - tree *fpTree - ids [][][]byte +type memIDStore struct { + ids [][][]byte } -type fpResult struct { - fp fingerprint - count uint32 -} +var _ idStore = &memIDStore{} -func newInMemFPTree(np *nodePool, maxDepth int) *inMemFPTree { - if maxDepth == 0 { - panic("BUG: can't use newInMemFPTree with zero maxDepth") +func (m *memIDStore) registerHash(h []byte, maxDepth int) error { + if m.ids == nil { + m.ids = make([][][]byte, 1<> (64 - 
mft.tree.maxDepth) - s := mft.ids[idx] + idx := load64(h) >> (64 - maxDepth) + s := m.ids[idx] n := slices.IndexFunc(s, func(cur []byte) bool { return bytes.Compare(cur, h) > 0 }) if n < 0 { - mft.ids[idx] = append(s, h) + m.ids[idx] = append(s, h) } else { - mft.ids[idx] = slices.Insert(s, n, h) + m.ids[idx] = slices.Insert(s, n, h) } + return nil } -func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { - r := mft.tree.aggregateInterval(x, y) - for _, t := range r.tailRefs { - ids := mft.ids[t] +func (m *memIDStore) iterateIDs(tailRefs []uint64, maxDepth int, toCall func(id []byte)) error { + for _, t := range tailRefs { + ids := m.ids[t] for _, id := range ids { - // FIXME: this can be optimized as the IDs are ordered - if idWithinInterval(id, x, y, r.itype) { - // fmt.Fprintf(os.Stderr, "QQQQQ: including tail: %s\n", hex.EncodeToString(id)) - r.fp.update(id) - r.count++ - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: NOT including tail: %s\n", hex.EncodeToString(id)) - } + toCall(id) + } + } + return nil +} + +type sqlIDStore struct { + db sql.StateDatabase +} + +func newSQLIDStore(db sql.StateDatabase) *sqlIDStore { + return &sqlIDStore{db: db} +} + +func (s *sqlIDStore) registerHash(h []byte, maxDepth int) error { + // should be registered by the handler code + return nil +} + +func (s *sqlIDStore) iterateIDs(tailRefs []uint64, maxDepth int, toCall func(id []byte)) error { + for _, t := range tailRefs { + p := mkprefix(t, maxDepth) + var minID, maxID types.Hash32 + p.minID(minID[:]) + p.maxID(maxID[:]) + // start := time.Now() + if _, err := s.db.Exec( + "select id from atxs where id between ? 
and ?", + func(stmt *sql.Statement) { + stmt.BindBytes(1, minID[:]) + stmt.BindBytes(2, maxID[:]) + }, + func(stmt *sql.Statement) bool { + var id types.Hash32 + stmt.ColumnBytes(0, id[:]) + toCall(id[:]) + return true + }, + ); err != nil { + return err } + // fmt.Fprintf(os.Stderr, "QQQQQ: %v: sel atxs between %s and %s\n", time.Now().Sub(start), minID.String(), maxID.String()) } - return fpResult{fp: r.fp, count: r.count} + return nil } +// type inMemFPTree struct { +// tree *fpTree +// ids [][][]byte +// } + +// func newInMemFPTree(np *nodePool, maxDepth int) *inMemFPTree { +// if maxDepth == 0 { +// panic("BUG: can't use newInMemFPTree with zero maxDepth") +// } +// return &inMemFPTree{ +// tree: newFPTree(np, maxDepth), +// ids: make([][][]byte, 1<> (64 - mft.tree.maxDepth) +// s := mft.ids[idx] +// n := slices.IndexFunc(s, func(cur []byte) bool { +// return bytes.Compare(cur, h) > 0 +// }) +// if n < 0 { +// mft.ids[idx] = append(s, h) +// } else { +// mft.ids[idx] = slices.Insert(s, n, h) +// } +// } + +// func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { +// r := mft.tree.aggregateInterval(x, y) +// for _, t := range r.tailRefs { +// ids := mft.ids[t] +// for _, id := range ids { +// // FIXME: this can be optimized as the IDs are ordered +// if idWithinInterval(id, x, y, r.itype) { +// // fmt.Fprintf(os.Stderr, "QQQQQ: including tail: %s\n", hex.EncodeToString(id)) +// r.fp.update(id) +// r.count++ +// } else { +// // fmt.Fprintf(os.Stderr, "QQQQQ: NOT including tail: %s\n", hex.EncodeToString(id)) +// } +// } +// } +// return fpResult{fp: r.fp, count: r.count} +// } + func idWithinInterval(id, x, y []byte, itype int) bool { switch itype { case 0: diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/dbsync_test.go index 36889ad329..45302e2ca5 100644 --- a/sync2/dbsync/dbsync_test.go +++ b/sync2/dbsync/dbsync_test.go @@ -2,14 +2,15 @@ package dbsync import ( "bytes" - "math/bits" + "fmt" "math/rand" + "runtime" "slices" "strings" 
"testing" + "time" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" @@ -25,6 +26,8 @@ func TestPrefix(t *testing.T) { left prefix right prefix shift prefix + minID string + maxID string }{ { p: 0, @@ -33,6 +36,8 @@ func TestPrefix(t *testing.T) { bits: 0, left: 0b0_000001, right: 0b1_000001, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + maxID: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0b0_000001, @@ -42,6 +47,8 @@ func TestPrefix(t *testing.T) { left: 0b00_000010, right: 0b01_000010, shift: 0, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + maxID: "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0b1_000001, @@ -51,6 +58,8 @@ func TestPrefix(t *testing.T) { left: 0b10_000010, right: 0b11_000010, shift: 0, + minID: "8000000000000000000000000000000000000000000000000000000000000000", + maxID: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0b00_000010, @@ -60,6 +69,8 @@ func TestPrefix(t *testing.T) { left: 0b000_000011, right: 0b001_000011, shift: 0b0_000001, + minID: "0000000000000000000000000000000000000000000000000000000000000000", + maxID: "3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0b01_000010, @@ -69,6 +80,8 @@ func TestPrefix(t *testing.T) { left: 0b010_000011, right: 0b011_000011, shift: 0b1_000001, + minID: "4000000000000000000000000000000000000000000000000000000000000000", + maxID: "7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0b10_000010, @@ -78,6 +91,8 @@ func TestPrefix(t *testing.T) { left: 0b100_000011, right: 0b101_000011, shift: 0b0_000001, + minID: "8000000000000000000000000000000000000000000000000000000000000000", + maxID: "BFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0b11_000010, @@ -87,6 +102,8 @@ 
func TestPrefix(t *testing.T) { left: 0b110_000011, right: 0b111_000011, shift: 0b1_000001, + minID: "C000000000000000000000000000000000000000000000000000000000000000", + maxID: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0x3fffffd8, @@ -96,6 +113,8 @@ func TestPrefix(t *testing.T) { left: 0x7fffff99, right: 0x7fffffd9, shift: 0x1fffffd7, + minID: "FFFFFF0000000000000000000000000000000000000000000000000000000000", + maxID: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, { p: 0x7fffff99, @@ -105,16 +124,30 @@ func TestPrefix(t *testing.T) { left: 0xffffff1a, right: 0xffffff5a, shift: 0x3fffff98, + minID: "FFFFFF0000000000000000000000000000000000000000000000000000000000", + maxID: "FFFFFF7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", }, } { - require.Equal(t, tc.s, tc.p.String()) - require.Equal(t, tc.bits, tc.p.bits()) - require.Equal(t, tc.len, tc.p.len()) - require.Equal(t, tc.left, tc.p.left()) - require.Equal(t, tc.right, tc.p.right()) - if tc.p != 0 { - require.Equal(t, tc.shift, tc.p.shift()) - } + t.Run(fmt.Sprint(tc.p), func(t *testing.T) { + require.Equal(t, tc.s, tc.p.String()) + require.Equal(t, tc.bits, tc.p.bits()) + require.Equal(t, tc.len, tc.p.len()) + require.Equal(t, tc.left, tc.p.left()) + require.Equal(t, tc.right, tc.p.right()) + if tc.p != 0 { + require.Equal(t, tc.shift, tc.p.shift()) + } + + expMinID := types.HexToHash32(tc.minID) + var minID types.Hash32 + tc.p.minID(minID[:]) + require.Equal(t, expMinID, minID) + + expMaxID := types.HexToHash32(tc.maxID) + var maxID types.Hash32 + tc.p.maxID(maxID[:]) + require.Equal(t, expMaxID, maxID) + }) } } @@ -155,49 +188,29 @@ func TestCommonPrefix(t *testing.T) { } } -const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" +type fakeATXStore struct { + db sql.StateDatabase + *sqlIDStore +} -func TestRmme(t *testing.T) { - t.Skip("slow tmp test") - counts := make(map[uint64]uint64) - 
prefLens := make(map[int]int) - db, err := statesql.Open("file:" + dbFile) - require.NoError(t, err) - defer db.Close() - var prev uint64 - first := true - // where epoch=23 - _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { - var id types.Hash32 - stmt.ColumnBytes(0, id[:]) - v := load64(id[:]) - counts[v>>40]++ - if first { - first = false - } else { - prefLens[bits.LeadingZeros64(prev^v)]++ - } - prev = v - return true - }) - require.NoError(t, err) - countFreq := make(map[uint64]int) - for _, c := range counts { - countFreq[c]++ - } - ks := maps.Keys(countFreq) - slices.Sort(ks) - for _, c := range ks { - t.Logf("%d: %d times", c, countFreq[c]) - } - pls := maps.Keys(prefLens) - slices.Sort(pls) - for _, pl := range pls { - t.Logf("pl %d: %d times", pl, prefLens[pl]) +func newFakeATXIDStore(db sql.StateDatabase) *fakeATXStore { + return &fakeATXStore{db: db, sqlIDStore: newSQLIDStore(db)} +} + +func (s *fakeATXStore) registerHash(h []byte, maxDepth int) error { + if err := s.sqlIDStore.registerHash(h, maxDepth); err != nil { + return err } + _, err := s.db.Exec(` + insert into atxs (id, epoch, effective_num_units, received) + values (?, 1, 1, 0)`, + func(stmt *sql.Statement) { + stmt.BindBytes(1, h) + }, nil) + return err } -func TestInMemFPTree(t *testing.T) { +func testFPTree(t *testing.T, idStore idStore) { for _, tc := range []struct { name string ids []string @@ -207,10 +220,10 @@ func TestInMemFPTree(t *testing.T) { name: "ids1", ids: []string{ "0000000000000000000000000000000000000000000000000000000000000000", - "123456789ABCDEF0000000000000000000000000000000000000000000000000", + "123456789abcdef0000000000000000000000000000000000000000000000000", "5555555555555555555555555555555555555555555555555555555555555555", "8888888888888888888888888888888888888888888888888888888888888888", - "ABCDEF1234567890000000000000000000000000000000000000000000000000", + 
"abcdef1234567890000000000000000000000000000000000000000000000000", }, results: map[[2]int]fpResult{ {0, 0}: { @@ -273,30 +286,43 @@ func TestInMemFPTree(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { var np nodePool - mft := newInMemFPTree(&np, 24) + ft := newFPTree(&np, idStore, 24) var hs []types.Hash32 for _, hex := range tc.ids { t.Logf("add: %s", hex) h := types.HexToHash32(hex) hs = append(hs, h) - mft.addHash(h[:]) + ft.addHash(h[:]) } var sb strings.Builder - mft.tree.dump(&sb) + ft.dump(&sb) t.Logf("tree:\n%s", sb.String()) - checkTree(t, mft.tree, 24) + checkTree(t, ft, 24) - for idRange, fpResult := range tc.results { + for idRange, expResult := range tc.results { x := hs[idRange[0]] y := hs[idRange[1]] - require.Equal(t, fpResult, mft.aggregateInterval(x[:], y[:])) + fpr, err := ft.fingerprintInterval(x[:], y[:]) + require.NoError(t, err) + require.Equal(t, expResult, fpr) } }) } } +func TestFPTree(t *testing.T) { + t.Run("in-memory id store", func(t *testing.T) { + testFPTree(t, &memIDStore{}) + }) + t.Run("fake ATX store", func(t *testing.T) { + db := statesql.InMemory() + defer db.Close() + testFPTree(t, newFakeATXIDStore(db)) + }) +} + type hashList []types.Hash32 func (l hashList) findGTE(h types.Hash32) int { @@ -336,42 +362,32 @@ func checkNode(t *testing.T, ft *fpTree, idx nodeIndex, depth int) { func checkTree(t *testing.T, ft *fpTree, maxDepth int) { require.Equal(t, maxDepth, ft.maxDepth) checkNode(t, ft, ft.root, 0) - } -func testInMemFPTreeManyItems(t *testing.T, randomXY bool) { +func testInMemFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool) { var np nodePool const ( numItems = 1 << 16 maxDepth = 24 ) - mft := newInMemFPTree(&np, maxDepth) + ft := newFPTree(&np, idStore, maxDepth) hs := make(hashList, numItems) var fp fingerprint - rmmeMap := make(map[types.Hash32]bool) for i := range hs { h := types.RandomHash() hs[i] = h - mft.addHash(h[:]) + ft.addHash(h[:]) fp.update(h[:]) - require.False(t, rmmeMap[h]) - 
rmmeMap[h] = true } slices.SortFunc(hs, func(a, b types.Hash32) int { return a.Compare(b) }) - total := 0 - nums := make(map[int]int) - for _, ids := range mft.ids { - nums[len(ids)]++ - total += len(ids) - } - t.Logf("total %d, numItems %d, nums %#v", total, numItems, nums) - - checkTree(t, mft.tree, maxDepth) + checkTree(t, ft, maxDepth) - require.Equal(t, fpResult{fp: fp, count: numItems}, mft.aggregateInterval(hs[0][:], hs[0][:])) + fpr, err := ft.fingerprintInterval(hs[0][:], hs[0][:]) + require.NoError(t, err) + require.Equal(t, fpResult{fp: fp, count: numItems}, fpr) for i := 0; i < 100; i++ { // TBD: allow reverse order var x, y types.Hash32 @@ -410,18 +426,237 @@ func testInMemFPTreeManyItems(t *testing.T, randomXY bool) { expFP = fp expN = numItems } + fpr, err := ft.fingerprintInterval(x[:], y[:]) + require.NoError(t, err) require.Equal(t, fpResult{ fp: expFP, count: expN, - }, mft.aggregateInterval(x[:], y[:])) + }, fpr) } } func TestInMemFPTreeManyItems(t *testing.T) { t.Run("bounds from the set", func(t *testing.T) { - testInMemFPTreeManyItems(t, false) + var idStore memIDStore + testInMemFPTreeManyItems(t, &idStore, false) + total := 0 + nums := make(map[int]int) + for _, ids := range idStore.ids { + nums[len(ids)]++ + total += len(ids) + } + t.Logf("total %d, nums %#v", total, nums) + }) t.Run("random bounds", func(t *testing.T) { - testInMemFPTreeManyItems(t, true) + testInMemFPTreeManyItems(t, &memIDStore{}, true) + }) + t.Run("SQL, bounds from the set", func(t *testing.T) { + db := statesql.InMemory() + defer db.Close() + testInMemFPTreeManyItems(t, newFakeATXIDStore(db), false) + + }) + t.Run("SQL, random bounds", func(t *testing.T) { + db := statesql.InMemory() + defer db.Close() + testInMemFPTreeManyItems(t, newFakeATXIDStore(db), true) }) } + +const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" + +func dumbAggATXs(t *testing.T, db sql.StateDatabase, x, y types.Hash32) fpResult { + var fp 
fingerprint + ts := time.Now() + nRows, err := db.Exec( + // BETWEEN is faster than >= and < + "select id from atxs where id between ? and ?", + func(stmt *sql.Statement) { + stmt.BindBytes(1, x[:]) + stmt.BindBytes(2, y[:]) + }, + func(stmt *sql.Statement) bool { + var id types.Hash32 + stmt.ColumnBytes(0, id[:]) + if id != y { + fp.update(id[:]) + } + return true + }, + ) + require.NoError(t, err) + t.Logf("QQQQQ: %v: dumb fp between %s and %s", time.Now().Sub(ts), x.String(), y.String()) + return fpResult{ + fp: fp, + count: uint32(nRows), + } +} + +func testFP(t *testing.T, maxDepth int) { + runtime.GC() + var stats1 runtime.MemStats + runtime.ReadMemStats(&stats1) + // t.Skip("slow tmp test") + // counts := make(map[uint64]uint64) + // prefLens := make(map[int]int) + db, err := statesql.Open("file:" + dbFile) + require.NoError(t, err) + defer db.Close() + // _, err = db.Exec("PRAGMA cache_size = -2000000", nil, nil) + // require.NoError(t, err) + // var prev uint64 + // first := true + // where epoch=23 + store := newSQLIDStore(db) + var np nodePool + ft := newFPTree(&np, store, maxDepth) + t.Logf("loading IDs") + _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { + var id types.Hash32 + stmt.ColumnBytes(0, id[:]) + ft.addHash(id[:]) + // v := load64(id[:]) + // counts[v>>40]++ + // if first { + // first = false + // } else { + // prefLens[bits.LeadingZeros64(prev^v)]++ + // } + // prev = v + return true + }) + require.NoError(t, err) + // countFreq := make(map[uint64]int) + // for _, c := range counts { + // countFreq[c]++ + // } + // ks := maps.Keys(countFreq) + // slices.Sort(ks) + // for _, c := range ks { + // t.Logf("%d: %d times", c, countFreq[c]) + // } + // pls := maps.Keys(prefLens) + // slices.Sort(pls) + // for _, pl := range pls { + // t.Logf("pl %d: %d times", pl, prefLens[pl]) + // } + + t.Logf("benchmarking ranges") + ts := time.Now() + const numIter = 20000 + for n := 0; n < numIter; n++ { + x := 
types.RandomHash() + y := types.RandomHash() + ft.fingerprintInterval(x[:], y[:]) + } + elapsed := time.Now().Sub(ts) + + runtime.GC() + var stats2 runtime.MemStats + runtime.ReadMemStats(&stats2) + t.Logf("range benchmark for maxDepth %d: %v per range, %f ranges/s, heap diff %d", + // it's important to use ft pointer here so it doesn't get freed + // before we read the mem stats + ft.maxDepth, + elapsed/numIter, + float64(numIter)/elapsed.Seconds(), + stats2.HeapInuse-stats1.HeapInuse) + + // TBD: restore !!!! + // t.Logf("testing ranges") + // for n := 0; n < 10; n++ { + // x := types.RandomHash() + // y := types.RandomHash() + // // TBD: QQQQQ: dumb rev / full intervals + // if x == y { + // continue + // } + // if x.Compare(y) > 0 { + // x, y = y, x + // } + // expFPResult := dumbAggATXs(t, db, x, y) + // fpr, err := ft.fingerprintInterval(x[:], y[:]) + // require.NoError(t, err) + // require.Equal(t, expFPResult, fpr) + // } +} + +func TestFP(t *testing.T) { + t.Skip("slow test") + for maxDepth := 15; maxDepth <= 23; maxDepth++ { + for i := 0; i < 3; i++ { + testFP(t, maxDepth) + } + } +} + +// benchmarks + +// maxDepth 18: 94.739µs per range, 10555.290991 ranges/s, heap diff 16621568 +// maxDepth 18: 95.837µs per range, 10434.316922 ranges/s, heap diff 16564224 +// maxDepth 18: 95.312µs per range, 10491.834238 ranges/s, heap diff 16588800 +// maxDepth 19: 60.822µs per range, 16441.200726 ranges/s, heap diff 32317440 +// maxDepth 19: 57.86µs per range, 17283.084675 ranges/s, heap diff 32333824 +// maxDepth 19: 58.183µs per range, 17187.139809 ranges/s, heap diff 32342016 +// maxDepth 20: 41.582µs per range, 24048.516680 ranges/s, heap diff 63094784 +// maxDepth 20: 41.384µs per range, 24163.830753 ranges/s, heap diff 63102976 +// maxDepth 20: 42.003µs per range, 23807.631953 ranges/s, heap diff 63053824 +// maxDepth 21: 31.996µs per range, 31253.349138 ranges/s, heap diff 123289600 +// maxDepth 21: 31.926µs per range, 31321.766830 ranges/s, heap diff 123256832 
+// maxDepth 21: 31.839µs per range, 31407.657854 ranges/s, heap diff 123256832 +// maxDepth 22: 27.829µs per range, 35933.122150 ranges/s, heap diff 240689152 +// maxDepth 22: 27.524µs per range, 36330.976995 ranges/s, heap diff 240689152 +// maxDepth 22: 27.386µs per range, 36514.410406 ranges/s, heap diff 240689152 +// maxDepth 23: 24.378µs per range, 41020.262869 ranges/s, heap diff 470024192 +// maxDepth 23: 24.605µs per range, 40641.096389 ranges/s, heap diff 470056960 +// maxDepth 23: 24.51µs per range, 40799.444720 ranges/s, heap diff 470040576 + +// maxDepth 18: 94.518µs per range, 10579.885738 ranges/s, heap diff 16621568 +// maxDepth 18: 95.144µs per range, 10510.332936 ranges/s, heap diff 16572416 +// maxDepth 18: 94.55µs per range, 10576.359829 ranges/s, heap diff 16588800 +// maxDepth 19: 60.463µs per range, 16538.974879 ranges/s, heap diff 32325632 +// maxDepth 19: 60.47µs per range, 16537.108181 ranges/s, heap diff 32358400 +// maxDepth 19: 60.441µs per range, 16544.939001 ranges/s, heap diff 32333824 +// maxDepth 20: 41.131µs per range, 24311.982297 ranges/s, heap diff 63078400 +// maxDepth 20: 41.621µs per range, 24026.119996 ranges/s, heap diff 63086592 +// maxDepth 20: 41.568µs per range, 24056.912641 ranges/s, heap diff 63094784 +// maxDepth 21: 32.234µs per range, 31022.459566 ranges/s, heap diff 123256832 +// maxDepth 21: 30.856µs per range, 32408.240119 ranges/s, heap diff 123248640 +// maxDepth 21: 30.774µs per range, 32494.318758 ranges/s, heap diff 123224064 +// maxDepth 22: 27.476µs per range, 36394.375781 ranges/s, heap diff 240689152 +// maxDepth 22: 27.707µs per range, 36091.188900 ranges/s, heap diff 240705536 +// maxDepth 22: 27.281µs per range, 36654.794863 ranges/s, heap diff 240705536 +// maxDepth 23: 24.394µs per range, 40992.220132 ranges/s, heap diff 470048768 +// maxDepth 23: 24.697µs per range, 40489.695824 ranges/s, heap diff 470040576 +// maxDepth 23: 24.436µs per range, 40923.081488 ranges/s, heap diff 470032384 + +// 
maxDepth 15: 529.513µs per range, 1888.524885 ranges/s, heap diff 2293760 +// maxDepth 15: 528.783µs per range, 1891.132520 ranges/s, heap diff 2244608 +// maxDepth 15: 529.458µs per range, 1888.723450 ranges/s, heap diff 2252800 +// maxDepth 16: 281.809µs per range, 3548.498801 ranges/s, heap diff 4390912 +// maxDepth 16: 280.159µs per range, 3569.389929 ranges/s, heap diff 4382720 +// maxDepth 16: 280.449µs per range, 3565.709031 ranges/s, heap diff 4390912 +// maxDepth 17: 157.429µs per range, 6352.037713 ranges/s, heap diff 8527872 +// maxDepth 17: 156.569µs per range, 6386.942961 ranges/s, heap diff 8527872 +// maxDepth 17: 157.158µs per range, 6362.998907 ranges/s, heap diff 8527872 +// maxDepth 18: 94.689µs per range, 10560.886016 ranges/s, heap diff 16547840 +// maxDepth 18: 95.995µs per range, 10417.191145 ranges/s, heap diff 16564224 +// maxDepth 18: 94.469µs per range, 10585.428908 ranges/s, heap diff 16515072 +// maxDepth 19: 61.218µs per range, 16334.822475 ranges/s, heap diff 32342016 +// maxDepth 19: 61.733µs per range, 16198.549404 ranges/s, heap diff 32350208 +// maxDepth 19: 61.269µs per range, 16321.226214 ranges/s, heap diff 32309248 +// maxDepth 20: 42.336µs per range, 23620.054892 ranges/s, heap diff 63053824 +// maxDepth 20: 41.906µs per range, 23862.511368 ranges/s, heap diff 63094784 +// maxDepth 20: 41.647µs per range, 24011.273302 ranges/s, heap diff 63086592 +// maxDepth 21: 32.895µs per range, 30399.444906 ranges/s, heap diff 123256832 +// maxDepth 21: 31.798µs per range, 31447.748207 ranges/s, heap diff 123256832 +// maxDepth 21: 32.008µs per range, 31241.248008 ranges/s, heap diff 123265024 +// maxDepth 22: 27.014µs per range, 37017.223157 ranges/s, heap diff 240689152 +// maxDepth 22: 26.764µs per range, 37363.422097 ranges/s, heap diff 240664576 +// maxDepth 22: 26.938µs per range, 37121.580267 ranges/s, heap diff 240664576 +// maxDepth 23: 24.457µs per range, 40887.173321 ranges/s, heap diff 470040576 +// maxDepth 23: 24.997µs per 
range, 40003.930386 ranges/s, heap diff 470040576 +// maxDepth 23: 24.741µs per range, 40418.462446 ranges/s, heap diff 470040576 + +// TBD: ensure short prefix problem is not a bug!!! From 4452894b1e9d7e539291e23a23ce1ae4a9379024 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 20 Jun 2024 23:37:48 +0400 Subject: [PATCH 37/76] hashsync: fix tests --- sync2/hashsync/setsyncbase_test.go | 8 +++++--- sync2/hashsync/split_sync_test.go | 32 +++++++++++++++--------------- sync2/hashsync/sync_queue_test.go | 20 +++++++++---------- 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/sync2/hashsync/setsyncbase_test.go b/sync2/hashsync/setsyncbase_test.go index 333b1db695..b25d473b62 100644 --- a/sync2/hashsync/setsyncbase_test.go +++ b/sync2/hashsync/setsyncbase_test.go @@ -56,14 +56,15 @@ func (st *setSyncBaseTester) getWaitCh(k Ordered) chan error { return ch } -func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types.Hash32) { +func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types.Hash32) *MockItemStore { + copy := NewMockItemStore(st.ctrl) st.is.EXPECT().Copy().DoAndReturn(func() ItemStore { - copy := NewMockItemStore(st.ctrl) for _, k := range addedKeys { copy.EXPECT().Add(ctx, k) } return copy }) + return copy } func (st *setSyncBaseTester) expectSyncStore( @@ -117,7 +118,8 @@ func TestSetSyncBase(t *testing.T) { Count: 42, Sim: 0.99, } - st.ps.EXPECT().Probe(ctx, p2p.Peer("p1"), st.is, nil, nil).Return(expPr, nil) + store := st.expectCopy(ctx) + st.ps.EXPECT().Probe(ctx, p2p.Peer("p1"), store, nil, nil).Return(expPr, nil) pr, err := st.ssb.Probe(ctx, p2p.Peer("p1")) require.NoError(t, err) require.Equal(t, expPr, pr) diff --git a/sync2/hashsync/split_sync_test.go b/sync2/hashsync/split_sync_test.go index e28bc45294..2eff5ede4e 100644 --- a/sync2/hashsync/split_sync_test.go +++ b/sync2/hashsync/split_sync_test.go @@ -20,7 +20,7 @@ import ( func hexDelimiters(n int) (r []string) { for _, h := range 
getDelimiters(n) { - r = append(r, h.Hex()) + r = append(r, h.String()) } return r } @@ -41,22 +41,22 @@ func TestGetDelimiters(t *testing.T) { { numPeers: 2, values: []string{ - "0x8000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", }, }, { numPeers: 3, values: []string{ - "0x5555555555555554000000000000000000000000000000000000000000000000", - "0xaaaaaaaaaaaaaaa8000000000000000000000000000000000000000000000000", + "5555555555555554000000000000000000000000000000000000000000000000", + "aaaaaaaaaaaaaaa8000000000000000000000000000000000000000000000000", }, }, { numPeers: 4, values: []string{ - "0x4000000000000000000000000000000000000000000000000000000000000000", - "0x8000000000000000000000000000000000000000000000000000000000000000", - "0xc000000000000000000000000000000000000000000000000000000000000000", + "4000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", + "c000000000000000000000000000000000000000000000000000000000000000", }, }, } { @@ -85,20 +85,20 @@ type splitSyncTester struct { var tstRanges = []hexRange{ { - "0x0000000000000000000000000000000000000000000000000000000000000000", - "0x4000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "4000000000000000000000000000000000000000000000000000000000000000", }, { - "0x4000000000000000000000000000000000000000000000000000000000000000", - "0x8000000000000000000000000000000000000000000000000000000000000000", + "4000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", }, { - "0x8000000000000000000000000000000000000000000000000000000000000000", - "0xc000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", + 
"c000000000000000000000000000000000000000000000000000000000000000", }, { - "0xc000000000000000000000000000000000000000000000000000000000000000", - "0x0000000000000000000000000000000000000000000000000000000000000000", + "c000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", }, } @@ -137,7 +137,7 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { require.NotNil(t, ctx) require.NotNil(t, x) require.NotNil(t, y) - k := hexRange{x.Hex(), y.Hex()} + k := hexRange{x.String(), y.String()} tst.peerRanges[k] = append(tst.peerRanges[k], peer) count, found := tst.expPeerRanges[k] require.True(t, found, "peer range not found: x %s y %s", x, y) diff --git a/sync2/hashsync/sync_queue_test.go b/sync2/hashsync/sync_queue_test.go index 8b0ca984ca..180d2aa0f2 100644 --- a/sync2/hashsync/sync_queue_test.go +++ b/sync2/hashsync/sync_queue_test.go @@ -12,20 +12,20 @@ type hexRange [2]string func TestSyncQueue(t *testing.T) { expPeerRanges := map[hexRange]bool{ { - "0x0000000000000000000000000000000000000000000000000000000000000000", - "0x4000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "4000000000000000000000000000000000000000000000000000000000000000", }: false, { - "0x4000000000000000000000000000000000000000000000000000000000000000", - "0x8000000000000000000000000000000000000000000000000000000000000000", + "4000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", }: false, { - "0x8000000000000000000000000000000000000000000000000000000000000000", - "0xc000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", + "c000000000000000000000000000000000000000000000000000000000000000", }: false, { - 
"0xc000000000000000000000000000000000000000000000000000000000000000", - "0x0000000000000000000000000000000000000000000000000000000000000000", + "c000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", }: false, } sq := newSyncQueue(4) @@ -37,7 +37,7 @@ func TestSyncQueue(t *testing.T) { require.True(t, sr.lastSyncStarted.IsZero()) require.False(t, sr.done) require.Zero(t, sr.numSyncers) - k := hexRange{sr.x.Hex(), sr.y.Hex()} + k := hexRange{sr.x.String(), sr.y.String()} processed, found := expPeerRanges[k] require.True(t, found) require.False(t, processed) @@ -60,7 +60,7 @@ func TestSyncQueue(t *testing.T) { require.Len(t, sq, 4) for i := 0; i < 4; i++ { sr := sq.popRange() - k := hexRange{sr.x.Hex(), sr.y.Hex()} + k := hexRange{sr.x.String(), sr.y.String()} t.Logf("pop range %v at %v", k, sr.lastSyncStarted) require.Equal(t, pushed[i], k) } From bf450f46ff33c01cdb2aabf396792fee37b1507e Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 21 Jun 2024 01:11:36 +0400 Subject: [PATCH 38/76] hashsync: make ItemStore and Iterator methods return errors --- sync2/hashsync/handler.go | 8 +- sync2/hashsync/handler_test.go | 15 ++-- sync2/hashsync/interface.go | 12 +-- sync2/hashsync/mocks_test.go | 79 +++++++++--------- sync2/hashsync/multipeer.go | 10 ++- sync2/hashsync/multipeer_test.go | 14 ++-- sync2/hashsync/rangesync.go | 58 +++++++++---- sync2/hashsync/rangesync_test.go | 125 +++++++++++++++++++---------- sync2/hashsync/setsyncbase.go | 28 +++++-- sync2/hashsync/setsyncbase_test.go | 13 ++- sync2/hashsync/sync_tree_store.go | 26 +++--- sync2/hashsync/xorsync_test.go | 14 +++- sync2/p2p_test.go | 15 ++-- 13 files changed, 269 insertions(+), 148 deletions(-) diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index a7555f0685..2b7939266e 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -164,7 +164,9 @@ func (c *wireConduit) SendItems(count, 
itemChunkSize int, it Iterator) error { panic("fakeConduit.SendItems: went got to the end of the tree") } msg.ContentKeys = append(msg.ContentKeys, it.Key().(types.Hash32)) - it.Next() + if err := it.Next(); err != nil { + return err + } n-- } if err := c.send(&msg); err != nil { @@ -209,7 +211,9 @@ func (c *wireConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sa for n := 0; n < sampleSize; n++ { m.Sample[n] = MinhashSampleItemFromHash32(it.Key().(types.Hash32)) // fmt.Fprintf(os.Stderr, "QQQQQ: m.Sample[%d] = %s\n", n, m.Sample[n]) - it.Next() + if err := it.Next(); err != nil { + return err + } } // fmt.Fprintf(os.Stderr, "QQQQQ: end sending items\n") if x == nil && y == nil { diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index ba96fbacbd..d60be49788 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -125,10 +125,11 @@ func (it *sliceIterator) Key() Ordered { return nil } -func (it *sliceIterator) Next() { +func (it *sliceIterator) Next() error { if len(it.s) != 0 { it.s = it.s[1:] } + return nil } type fakeSend struct { @@ -466,16 +467,20 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { storeA, getRequester, opts, func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { pss := NewPairwiseStoreSyncer(client, opts) - minA := storeA.Min().Key() - infoA := storeA.GetRangeInfo(nil, minA, minA, -1) + minA, err := storeA.Min() + require.NoError(t, err) + infoA, err := storeA.GetRangeInfo(nil, minA.Key(), minA.Key(), -1) + require.NoError(t, err) prA, err := pss.Probe(ctx, srvPeerID, storeB, nil, nil) require.NoError(t, err) require.Equal(t, infoA.Fingerprint, prA.FP) require.Equal(t, infoA.Count, prA.Count) require.InDelta(t, 0.98, prA.Sim, 0.05, "sim") - minA = storeA.Min().Key() - partInfoA := storeA.GetRangeInfo(nil, minA, minA, infoA.Count/2) + minA, err = storeA.Min() + require.NoError(t, err) + partInfoA, err := storeA.GetRangeInfo(nil, 
minA.Key(), minA.Key(), infoA.Count/2) + require.NoError(t, err) x := partInfoA.Start.Key().(types.Hash32) y := partInfoA.End.Key().(types.Hash32) // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 6038750122..63fdea7d4f 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -19,7 +19,7 @@ type Iterator interface { // nil if the ItemStore is empty Key() Ordered // Next advances the iterator - Next() + Next() error } type RangeInfo struct { @@ -37,17 +37,17 @@ type ItemStore interface { // is returned for the corresponding subrange of the requested range. // If both x and y is nil, the whole set of items is used. // If only x or only y is nil, GetRangeInfo panics - GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo + GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) // Min returns the iterator pointing at the minimum element // in the store. If the store is empty, it returns nil - Min() Iterator + Min() (Iterator, error) // Max returns the iterator pointing at the maximum element // in the store. If the store is empty, it returns nil - Max() Iterator + Max() (Iterator, error) // Copy makes a shallow copy of the ItemStore Copy() ItemStore // Has returns true if the specified key is present in ItemStore - Has(k Ordered) bool + Has(k Ordered) (bool, error) } type Requester interface { @@ -56,7 +56,7 @@ type Requester interface { } type SyncBase interface { - Count() int + Count() (int, error) Derive(p p2p.Peer) Syncer Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) Wait() error diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go index a78fc16861..861c3d1ad8 100644 --- a/sync2/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -120,9 +120,11 @@ func (c *MockIteratorKeyCall) DoAndReturn(f func() Ordered) *MockIteratorKeyCall } // Next mocks base method. 
-func (m *MockIterator) Next() { +func (m *MockIterator) Next() error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Next") + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(error) + return ret0 } // Next indicates an expected call of Next. @@ -138,19 +140,19 @@ type MockIteratorNextCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockIteratorNextCall) Return() *MockIteratorNextCall { - c.Call = c.Call.Return() +func (c *MockIteratorNextCall) Return(arg0 error) *MockIteratorNextCall { + c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockIteratorNextCall) Do(f func()) *MockIteratorNextCall { +func (c *MockIteratorNextCall) Do(f func() error) *MockIteratorNextCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockIteratorNextCall) DoAndReturn(f func()) *MockIteratorNextCall { +func (c *MockIteratorNextCall) DoAndReturn(f func() error) *MockIteratorNextCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -255,11 +257,12 @@ func (c *MockItemStoreCopyCall) DoAndReturn(f func() ItemStore) *MockItemStoreCo } // GetRangeInfo mocks base method. -func (m *MockItemStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { +func (m *MockItemStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetRangeInfo", preceding, x, y, count) ret0, _ := ret[0].(RangeInfo) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // GetRangeInfo indicates an expected call of GetRangeInfo. 
@@ -275,29 +278,30 @@ type MockItemStoreGetRangeInfoCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockItemStoreGetRangeInfoCall) Return(arg0 RangeInfo) *MockItemStoreGetRangeInfoCall { - c.Call = c.Call.Return(arg0) +func (c *MockItemStoreGetRangeInfoCall) Return(arg0 RangeInfo, arg1 error) *MockItemStoreGetRangeInfoCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreGetRangeInfoCall) Do(f func(Iterator, Ordered, Ordered, int) RangeInfo) *MockItemStoreGetRangeInfoCall { +func (c *MockItemStoreGetRangeInfoCall) Do(f func(Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreGetRangeInfoCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) RangeInfo) *MockItemStoreGetRangeInfoCall { +func (c *MockItemStoreGetRangeInfoCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { c.Call = c.Call.DoAndReturn(f) return c } // Has mocks base method. -func (m *MockItemStore) Has(k Ordered) bool { +func (m *MockItemStore) Has(k Ordered) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Has", k) ret0, _ := ret[0].(bool) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Has indicates an expected call of Has. 
@@ -313,29 +317,30 @@ type MockItemStoreHasCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockItemStoreHasCall) Return(arg0 bool) *MockItemStoreHasCall { - c.Call = c.Call.Return(arg0) +func (c *MockItemStoreHasCall) Return(arg0 bool, arg1 error) *MockItemStoreHasCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreHasCall) Do(f func(Ordered) bool) *MockItemStoreHasCall { +func (c *MockItemStoreHasCall) Do(f func(Ordered) (bool, error)) *MockItemStoreHasCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreHasCall) DoAndReturn(f func(Ordered) bool) *MockItemStoreHasCall { +func (c *MockItemStoreHasCall) DoAndReturn(f func(Ordered) (bool, error)) *MockItemStoreHasCall { c.Call = c.Call.DoAndReturn(f) return c } // Max mocks base method. -func (m *MockItemStore) Max() Iterator { +func (m *MockItemStore) Max() (Iterator, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Max") ret0, _ := ret[0].(Iterator) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Max indicates an expected call of Max. 
@@ -351,29 +356,30 @@ type MockItemStoreMaxCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockItemStoreMaxCall) Return(arg0 Iterator) *MockItemStoreMaxCall { - c.Call = c.Call.Return(arg0) +func (c *MockItemStoreMaxCall) Return(arg0 Iterator, arg1 error) *MockItemStoreMaxCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreMaxCall) Do(f func() Iterator) *MockItemStoreMaxCall { +func (c *MockItemStoreMaxCall) Do(f func() (Iterator, error)) *MockItemStoreMaxCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreMaxCall) DoAndReturn(f func() Iterator) *MockItemStoreMaxCall { +func (c *MockItemStoreMaxCall) DoAndReturn(f func() (Iterator, error)) *MockItemStoreMaxCall { c.Call = c.Call.DoAndReturn(f) return c } // Min mocks base method. -func (m *MockItemStore) Min() Iterator { +func (m *MockItemStore) Min() (Iterator, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Min") ret0, _ := ret[0].(Iterator) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Min indicates an expected call of Min. 
@@ -389,19 +395,19 @@ type MockItemStoreMinCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockItemStoreMinCall) Return(arg0 Iterator) *MockItemStoreMinCall { - c.Call = c.Call.Return(arg0) +func (c *MockItemStoreMinCall) Return(arg0 Iterator, arg1 error) *MockItemStoreMinCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreMinCall) Do(f func() Iterator) *MockItemStoreMinCall { +func (c *MockItemStoreMinCall) Do(f func() (Iterator, error)) *MockItemStoreMinCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreMinCall) DoAndReturn(f func() Iterator) *MockItemStoreMinCall { +func (c *MockItemStoreMinCall) DoAndReturn(f func() (Iterator, error)) *MockItemStoreMinCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -534,11 +540,12 @@ func (m *MockSyncBase) EXPECT() *MockSyncBaseMockRecorder { } // Count mocks base method. -func (m *MockSyncBase) Count() int { +func (m *MockSyncBase) Count() (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Count") ret0, _ := ret[0].(int) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Count indicates an expected call of Count. 
@@ -554,19 +561,19 @@ type MockSyncBaseCountCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockSyncBaseCountCall) Return(arg0 int) *MockSyncBaseCountCall { - c.Call = c.Call.Return(arg0) +func (c *MockSyncBaseCountCall) Return(arg0 int, arg1 error) *MockSyncBaseCountCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockSyncBaseCountCall) Do(f func() int) *MockSyncBaseCountCall { +func (c *MockSyncBaseCountCall) Do(f func() (int, error)) *MockSyncBaseCountCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncBaseCountCall) DoAndReturn(f func() int) *MockSyncBaseCountCall { +func (c *MockSyncBaseCountCall) DoAndReturn(f func() (int, error)) *MockSyncBaseCountCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/sync2/hashsync/multipeer.go b/sync2/hashsync/multipeer.go index 7ea20d630d..3978d1a9ef 100644 --- a/sync2/hashsync/multipeer.go +++ b/sync2/hashsync/multipeer.go @@ -179,17 +179,21 @@ func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p. 
zap.Int("count", pr.Count)) } - if (1-pr.Sim)*float64(mpr.syncBase.Count()) < float64(mpr.maxFullDiff) { + c, err := mpr.syncBase.Count() + if err != nil { + return s, err + } + if (1-pr.Sim)*float64(c) < float64(mpr.maxFullDiff) { mpr.logger.Debug("nearFull peer", zap.Stringer("peer", p), zap.Float64("sim", pr.Sim), - zap.Int("localCount", mpr.syncBase.Count())) + zap.Int("localCount", c)) s.nearFullCount++ } else { mpr.logger.Debug("nearFull peer", zap.Stringer("peer", p), zap.Float64("sim", pr.Sim), - zap.Int("localCount", mpr.syncBase.Count())) + zap.Int("localCount", c)) } } return s, nil diff --git a/sync2/hashsync/multipeer_test.go b/sync2/hashsync/multipeer_test.go index ecc9ad6c1e..271a204476 100644 --- a/sync2/hashsync/multipeer_test.go +++ b/sync2/hashsync/multipeer_test.go @@ -131,7 +131,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(10) // Advance by peer wait time. After that, 6 peers will be selected // randomly and probed - mt.syncBase.EXPECT().Count().Return(50).AnyTimes() + mt.syncBase.EXPECT().Count().Return(50, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{ FP: "foo", @@ -159,7 +159,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{ FP: "foo", @@ -179,7 +179,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(1) - mt.syncBase.EXPECT().Count().Return(50).AnyTimes() + mt.syncBase.EXPECT().Count().Return(50, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(1, ProbeResult{ FP: "foo", @@ -199,7 +199,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100, 
nil).AnyTimes() mt.syncBase.EXPECT().Probe(gomock.Any(), gomock.Any()). Return(ProbeResult{}, errors.New("probe failed")) mt.expectProbe(5, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) @@ -214,7 +214,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 3) @@ -230,7 +230,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 0) @@ -246,7 +246,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100).AnyTimes() + mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.syncRunner.EXPECT().fullSync(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, peers []p2p.Peer) error { diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 40efc9cc3f..a5908e0840 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -188,7 +188,10 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", // qqqqRmmeK(preceding)) // TODO: don't re-request range info for the first part of range after stop - info := rsr.is.GetRangeInfo(preceding, x, y, -1) + info, err := rsr.is.GetRangeInfo(preceding, x, y, -1) + if err != nil { + return nil, err + } // fmt.Fprintf(os.Stderr, "QQQQQ: start=%q end=%q info.Start=%q info.End=%q info.FP=%q 
x=%q y=%q\n", // qqqqRmmeK(start), qqqqRmmeK(end), qqqqRmmeK(info.Start), qqqqRmmeK(info.End), info.Fingerprint, x, y) switch { @@ -229,11 +232,17 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg if msg.Type() == MessageTypeEmptySet || (msg.Type() == MessageTypeProbe && x == nil && y == nil) { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) - it := rsr.is.Min() + it, err := rsr.is.Min() + if err != nil { + return nil, false, err + } if it == nil { // We don't have any items at all, too if msg.Type() == MessageTypeProbe { - info := rsr.is.GetRangeInfo(preceding, nil, nil, -1) + info, err := rsr.is.GetRangeInfo(preceding, nil, nil, -1) + if err != nil { + return nil, false, err + } if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, 0, it); err != nil { return nil, false, err } @@ -245,7 +254,10 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg } else if x == nil || y == nil { return nil, false, errors.New("bad X or Y") } - info := rsr.is.GetRangeInfo(preceding, x, y, -1) + info, err := rsr.is.GetRangeInfo(preceding, x, y, -1) + if err != nil { + return nil, false, err + } // fmt.Fprintf(os.Stderr, "QQQQQ msg %s %#v fp %v start %#v end %#v count %d\n", msg.Type(), msg, info.Fingerprint, info.Start, info.End, info.Count) switch { case msg.Type() == MessageTypeEmptyRange || @@ -311,7 +323,10 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // Note that there's no special handling for rollover ranges with x >= y // These need to be handled by ItemStore.GetRangeInfo() count := (info.Count + 1) / 2 - part := rsr.is.GetRangeInfo(preceding, x, y, count) + part, err := rsr.is.GetRangeInfo(preceding, x, y, count) + if err != nil { + return nil, false, err + } if part.End == nil { panic("BUG: can't split range with count > 1") } @@ -333,7 +348,10 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg } func 
(rsr *RangeSetReconciler) Initiate(c Conduit) error { - it := rsr.is.Min() + it, err := rsr.is.Min() + if err != nil { + return err + } var x Ordered if it != nil { x = it.Key() @@ -347,7 +365,10 @@ func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { return err } } else { - info := rsr.is.GetRangeInfo(nil, x, y, -1) + info, err := rsr.is.GetRangeInfo(nil, x, y, -1) + if err != nil { + return err + } switch { case info.Count == 0: panic("empty full min-min range") @@ -396,7 +417,10 @@ func (rsr *RangeSetReconciler) InitiateProbe(c Conduit) (RangeInfo, error) { } func (rsr *RangeSetReconciler) InitiateBoundedProbe(c Conduit, x, y Ordered) (RangeInfo, error) { - info := rsr.is.GetRangeInfo(nil, x, y, -1) + info, err := rsr.is.GetRangeInfo(nil, x, y, -1) + if err != nil { + return RangeInfo{}, err + } // fmt.Fprintf(os.Stderr, "QQQQQ: x %#v y %#v count %d\n", x, y, info.Count) if err := c.SendProbe(x, y, info.Fingerprint, rsr.sampleSize); err != nil { return RangeInfo{}, err @@ -407,12 +431,12 @@ func (rsr *RangeSetReconciler) InitiateBoundedProbe(c Conduit, x, y Ordered) (Ra return info, nil } -func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample []Ordered, fp any) float64 { +func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample []Ordered, fp any) (float64, error) { if fingerprintEqual(info.Fingerprint, fp) { - return 1 + return 1, nil } if info.Start == nil { - return 0 + return 0, nil } sampleSize := min(info.Count, rsr.sampleSize) localSample := make([]Ordered, sampleSize) @@ -424,7 +448,9 @@ func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ panic("BUG: no key") } localSample[n] = c.ShortenKey(it.Key()) - it.Next() + if err := it.Next(); err != nil { + return 0, err + } } slices.SortFunc(remoteSample, func(a, b Ordered) int { return a.Compare(b) }) slices.SortFunc(localSample, func(a, b Ordered) int { return a.Compare(b) }) @@ -448,7 +474,7 @@ func (rsr 
*RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ } maxSampleSize := max(sampleSize, len(remoteSample)) // fmt.Fprintf(os.Stderr, "QQQQQ: numEq %d maxSampleSize %d\n", numEq, maxSampleSize) - return float64(numEq) / float64(maxSampleSize) + return float64(numEq) / float64(maxSampleSize), nil } func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (pr ProbeResult, err error) { @@ -480,7 +506,11 @@ func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (p } pr.FP = msg.Fingerprint() pr.Count = msg.Count() - pr.Sim = rsr.calcSim(c, info, msg.Keys(), msg.Fingerprint()) + sim, err := rsr.calcSim(c, info, msg.Keys(), msg.Fingerprint()) + if err != nil { + return ProbeResult{}, fmt.Errorf("database error: %w", err) + } + pr.Sim = sim gotRange = true case MessageTypeEmptySet, MessageTypeEmptyRange: if gotRange { diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index deeac28c79..b579391974 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -123,7 +123,9 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { panic("fakeConduit.SendItems: went got to the end of the tree") } msg.keys = append(msg.keys, it.Key()) - it.Next() + if err := it.Next(); err != nil { + return err + } n-- } fc.sendMsg(msg) @@ -164,7 +166,9 @@ func (fc *fakeConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, s for n := 0; n < sampleSize; n++ { require.NotNil(fc.t, it.Key()) msg.keys[n] = it.Key() - it.Next() + if err := it.Next(); err != nil { + return err + } } fc.sendMsg(msg) return nil @@ -193,10 +197,11 @@ func (it *dumbStoreIterator) Key() Ordered { return it.ds.keys[it.n] } -func (it *dumbStoreIterator) Next() { +func (it *dumbStoreIterator) Next() error { if len(it.ds.keys) != 0 { it.n = (it.n + 1) % len(it.ds.keys) } + return nil } type dumbStore struct { @@ -248,13 +253,16 @@ func (ds *dumbStore) iterFor(s sampleID) 
Iterator { return ds.iter(n) } -func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { +func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { if x == nil && y == nil { - it := ds.Min() + it, err := ds.Min() + if err != nil { + return RangeInfo{}, err + } if it == nil { return RangeInfo{ Fingerprint: "", - } + }, nil } else { x = it.Key() y = x @@ -280,40 +288,40 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) R r.Start = ds.iterFor(sampleID(startStr)) r.End = ds.iterFor(sampleID(endStr)) } - return r + return r, nil } -func (ds *dumbStore) Min() Iterator { +func (ds *dumbStore) Min() (Iterator, error) { if len(ds.keys) == 0 { - return nil + return nil, nil } return &dumbStoreIterator{ ds: ds, n: 0, - } + }, nil } -func (ds *dumbStore) Max() Iterator { +func (ds *dumbStore) Max() (Iterator, error) { if len(ds.keys) == 0 { - return nil + return nil, nil } return &dumbStoreIterator{ ds: ds, n: len(ds.keys) - 1, - } + }, nil } func (ds *dumbStore) Copy() ItemStore { return &dumbStore{keys: slices.Clone(ds.keys)} } -func (ds *dumbStore) Has(k Ordered) bool { +func (ds *dumbStore) Has(k Ordered) (bool, error) { for _, cur := range ds.keys { if k.Compare(cur) == 0 { - return true + return true, nil } } - return false + return false, nil } type verifiedStoreIterator struct { @@ -342,10 +350,18 @@ func (it verifiedStoreIterator) Key() Ordered { return k2 } -func (it verifiedStoreIterator) Next() { - it.knownGood.Next() - it.it.Next() - assert.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys for Next()") +func (it verifiedStoreIterator) Next() error { + err1 := it.knownGood.Next() + err2 := it.it.Next() + switch { + case err1 == nil && err2 == nil: + assert.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys for Next()") + case err1 != nil && err2 != nil: + return err2 + default: + assert.Fail(it.t, "iterator error mismatch") + } + return nil } type 
verifiedStore struct { @@ -382,15 +398,22 @@ func (vs *verifiedStore) Add(ctx context.Context, k Ordered) error { return nil } -func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { - var ri1, ri2 RangeInfo +func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { + var ( + ri1, ri2 RangeInfo + err error + ) if preceding != nil { p := preceding.(verifiedStoreIterator) - ri1 = vs.knownGood.GetRangeInfo(p.knownGood, x, y, count) - ri2 = vs.store.GetRangeInfo(p.it, x, y, count) + ri1, err = vs.knownGood.GetRangeInfo(p.knownGood, x, y, count) + require.NoError(vs.t, err) + ri2, err = vs.store.GetRangeInfo(p.it, x, y, count) + require.NoError(vs.t, err) } else { - ri1 = vs.knownGood.GetRangeInfo(nil, x, y, count) - ri2 = vs.store.GetRangeInfo(nil, x, y, count) + ri1, err = vs.knownGood.GetRangeInfo(nil, x, y, count) + require.NoError(vs.t, err) + ri2, err = vs.store.GetRangeInfo(nil, x, y, count) + require.NoError(vs.t, err) } require.Equal(vs.t, ri1.Fingerprint, ri2.Fingerprint, "range info fingerprint") require.Equal(vs.t, ri1.Count, ri2.Count, "range info count") @@ -426,15 +449,17 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in } // QQQQQ: TODO: if count >= 0 and start+end != nil, do more calls to GetRangeInfo using resulting // end iterator key to make sure the range is correct - return ri + return ri, nil } -func (vs *verifiedStore) Min() Iterator { - m1 := vs.knownGood.Min() - m2 := vs.store.Min() +func (vs *verifiedStore) Min() (Iterator, error) { + m1, err := vs.knownGood.Min() + require.NoError(vs.t, err) + m2, err := vs.store.Min() + require.NoError(vs.t, err) if m1 == nil { require.Nil(vs.t, m2, "Min") - return nil + return nil, nil } else { require.NotNil(vs.t, m2, "Min") require.Equal(vs.t, m1.Key(), m2.Key(), "Min key") @@ -443,15 +468,17 @@ func (vs *verifiedStore) Min() Iterator { t: vs.t, knownGood: m1, it: m2, - } + }, nil } 
-func (vs *verifiedStore) Max() Iterator { - m1 := vs.knownGood.Max() - m2 := vs.store.Max() +func (vs *verifiedStore) Max() (Iterator, error) { + m1, err := vs.knownGood.Max() + require.NoError(vs.t, err) + m2, err := vs.store.Max() + require.NoError(vs.t, err) if m1 == nil { require.Nil(vs.t, m2, "Max") - return nil + return nil, nil } else { require.NotNil(vs.t, m2, "Max") require.Equal(vs.t, m1.Key(), m2.Key(), "Max key") @@ -460,7 +487,7 @@ func (vs *verifiedStore) Max() Iterator { t: vs.t, knownGood: m1, it: m2, - } + }, nil } func (vs *verifiedStore) Copy() ItemStore { @@ -472,11 +499,13 @@ func (vs *verifiedStore) Copy() ItemStore { } } -func (vs *verifiedStore) Has(k Ordered) bool { - h1 := vs.knownGood.Has(k) - h2 := vs.store.Has(k) +func (vs *verifiedStore) Has(k Ordered) (bool, error) { + h1, err := vs.knownGood.Has(k) + require.NoError(vs.t, err) + h2, err := vs.store.Has(k) + require.NoError(vs.t, err) require.Equal(vs.t, h1, h2) - return h2 + return h2, nil } type storeFactory func(t *testing.T) ItemStore @@ -506,15 +535,23 @@ func makeStore(t *testing.T, f storeFactory, items string) ItemStore { } func storeItemStr(is ItemStore) string { - it := is.Min() + it, err := is.Min() + if err != nil { + panic("store min error") + } if it == nil { return "" } - endAt := is.Min() + endAt, err := is.Min() + if err != nil { + panic("store min error") + } r := "" for { r += string(it.Key().(sampleID)) - it.Next() + if err := it.Next(); err != nil { + panic("iterator error") + } if it.Equal(endAt) { return r } diff --git a/sync2/hashsync/setsyncbase.go b/sync2/hashsync/setsyncbase.go index 869a5a57c4..ed49c86e17 100644 --- a/sync2/hashsync/setsyncbase.go +++ b/sync2/hashsync/setsyncbase.go @@ -35,15 +35,20 @@ func NewSetSyncBase(ps PairwiseSyncer, is ItemStore, handler SyncKeyHandler) *Se } // Count implements syncBase. 
-func (ssb *SetSyncBase) Count() int { +func (ssb *SetSyncBase) Count() (int, error) { + // TODO: don't lock on db-bound operations ssb.Lock() defer ssb.Unlock() - it := ssb.is.Min() - if it == nil { - return 0 + it, err := ssb.is.Min() + if it == nil || err != nil { + return 0, err } x := it.Key() - return ssb.is.GetRangeInfo(nil, x, x, -1).Count + info, err := ssb.is.GetRangeInfo(nil, x, x, -1) + if err != nil { + return 0, err + } + return info.Count, nil } // Derive implements syncBase. @@ -67,11 +72,15 @@ func (ssb *SetSyncBase) Probe(ctx context.Context, p p2p.Peer) (ProbeResult, err return ssb.ps.Probe(ctx, p, is, nil, nil) } -func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) { +func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) error { ssb.Lock() defer ssb.Unlock() key := k.(fmt.Stringer).String() - if !ssb.is.Has(k) { + has, err := ssb.is.Has(k) + if err != nil { + return err + } + if !has { ssb.waiting = append(ssb.waiting, ssb.g.DoChan(key, func() (any, error) { err := ssb.handler(ctx, k, p) @@ -83,6 +92,7 @@ func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) { return key, err })) } + return nil } func (ssb *SetSyncBase) Wait() error { @@ -131,6 +141,8 @@ func (ss *setSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter // Add implements ItemStore. 
func (ss *setSyncer) Add(ctx context.Context, k Ordered) error { - ss.acceptKey(ctx, k, ss.p) + if err := ss.acceptKey(ctx, k, ss.p); err != nil { + return err + } return ss.ItemStore.Add(ctx, k) } diff --git a/sync2/hashsync/setsyncbase_test.go b/sync2/hashsync/setsyncbase_test.go index b25d473b62..fe92763f8b 100644 --- a/sync2/hashsync/setsyncbase_test.go +++ b/sync2/hashsync/setsyncbase_test.go @@ -240,8 +240,11 @@ func TestSetSyncBase(t *testing.T) { ss.(ItemStore).Add(context.Background(), hs[2]) ss.(ItemStore).Add(context.Background(), hs[3]) // syncer's cloned ItemStore has new key immediately - require.True(t, ss.(ItemStore).Has(hs[2])) - require.True(t, ss.(ItemStore).Has(hs[3])) + has, err := ss.(ItemStore).Has(hs[2]) + require.NoError(t, err) + require.True(t, has) + has, err = ss.(ItemStore).Has(hs[3]) + require.True(t, has) handlerErr := errors.New("fail") st.getWaitCh(hs[2]) <- handlerErr close(st.getWaitCh(hs[3])) @@ -249,7 +252,9 @@ func TestSetSyncBase(t *testing.T) { require.ErrorIs(t, err, handlerErr) require.ElementsMatch(t, hs[2:], handledKeys) // only successfully handled key propagate the syncBase - require.False(t, is.Has(hs[2])) - require.True(t, is.Has(hs[3])) + has, err = is.Has(hs[2]) + require.False(t, has) + has, err = is.Has(hs[3]) + require.True(t, has) }) } diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index bfc949f362..65ce3506b7 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -21,11 +21,12 @@ func (it *syncTreeIterator) Key() Ordered { return it.ptr.Key() } -func (it *syncTreeIterator) Next() { +func (it *syncTreeIterator) Next() error { it.ptr.Next() if it.ptr.Key() == nil { it.ptr = it.st.Min() } + return nil } type SyncTreeStore struct { @@ -59,13 +60,16 @@ func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { } // GetRangeInfo implements ItemStore. 
-func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) RangeInfo { +func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { if x == nil && y == nil { - it := sts.Min() + it, err := sts.Min() + if err != nil { + return RangeInfo{}, err + } if it == nil { return RangeInfo{ Fingerprint: sts.identity, - } + }, nil } else { x = it.Key() y = x @@ -94,17 +98,17 @@ func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count i Count: cfp.Second.(int), Start: sts.iter(startPtr), End: sts.iter(endPtr), - } + }, nil } // Min implements ItemStore. -func (sts *SyncTreeStore) Min() Iterator { - return sts.iter(sts.st.Min()) +func (sts *SyncTreeStore) Min() (Iterator, error) { + return sts.iter(sts.st.Min()), nil } // Max implements ItemStore. -func (sts *SyncTreeStore) Max() Iterator { - return sts.iter(sts.st.Max()) +func (sts *SyncTreeStore) Max() (Iterator, error) { + return sts.iter(sts.st.Max()), nil } // Copy implements ItemStore. @@ -116,7 +120,7 @@ func (sts *SyncTreeStore) Copy() ItemStore { } // Has implements ItemStore. 
-func (sts *SyncTreeStore) Has(k Ordered) bool { +func (sts *SyncTreeStore) Has(k Ordered) (bool, error) { _, found := sts.st.Lookup(k) - return found + return found, nil } diff --git a/sync2/hashsync/xorsync_test.go b/sync2/hashsync/xorsync_test.go index aad8ccdbce..a017b24155 100644 --- a/sync2/hashsync/xorsync_test.go +++ b/sync2/hashsync/xorsync_test.go @@ -31,14 +31,22 @@ func TestHash32To12Xor(t *testing.T) { } func collectStoreItems[K Ordered](is ItemStore) (r []K) { - it := is.Min() + it, err := is.Min() + if err != nil { + panic("store min error") + } if it == nil { return nil } - endAt := is.Min() + endAt, err := is.Min() + if err != nil { + panic("store min error") + } for { r = append(r, it.Key().(K)) - it.Next() + if err := it.Next(); err != nil { + panic("iterator error") + } if it.Equal(endAt) { return r } diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index 35665a2fd2..a72e5d7e14 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -70,11 +70,14 @@ func TestP2P(t *testing.T) { for _, hsync := range hs { // use a snapshot to avoid races is := hsync.ItemStore().Copy() - it := is.Min() + it, err := is.Min() + require.NoError(t, err) if it == nil { return false } - if is.GetRangeInfo(nil, it.Key(), it.Key(), -1).Count < numHashes { + info, err := is.GetRangeInfo(nil, it.Key(), it.Key(), -1) + require.NoError(t, err) + if info.Count < numHashes { return false } } @@ -83,14 +86,16 @@ func TestP2P(t *testing.T) { for _, hsync := range hs { hsync.Stop() - min := hsync.ItemStore().Min() - it := hsync.ItemStore().Min() + min, err := hsync.ItemStore().Min() + require.NoError(t, err) + it, err := hsync.ItemStore().Min() + require.NoError(t, err) require.NotNil(t, it) var actualItems []types.Hash32 for { k := it.Key().(types.Hash32) actualItems = append(actualItems, k) - it.Next() + require.NoError(t, it.Next()) if it.Equal(min) { break } From 2950cab2fa1c55c3b562a64ac25ccaa1155abee9 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 25 Jun 2024 
08:37:49 +0400 Subject: [PATCH 39/76] wip: fptree w/o end node works --- sync2/dbsync/dbitemstore.go | 97 ++ sync2/dbsync/dbiter.go | 211 ++++ sync2/dbsync/dbiter_test.go | 276 +++++ sync2/dbsync/dbsync.go | 771 ------------ sync2/dbsync/fptree.go | 1101 +++++++++++++++++ .../dbsync/{dbsync_test.go => fptree_test.go} | 537 +++++--- sync2/dbsync/refcountpool.go | 34 +- sync2/dbsync/refcountpool_test.go | 19 +- sync2/hashsync/handler_test.go | 19 +- sync2/hashsync/interface.go | 7 +- sync2/hashsync/mocks_test.go | 77 -- sync2/hashsync/rangesync.go | 31 + sync2/hashsync/rangesync_test.go | 70 +- sync2/hashsync/sync_tree.go | 15 - sync2/hashsync/sync_tree_store.go | 5 - sync2/hashsync/xorsync_test.go | 29 +- sync2/p2p_test.go | 14 +- 17 files changed, 2184 insertions(+), 1129 deletions(-) create mode 100644 sync2/dbsync/dbitemstore.go create mode 100644 sync2/dbsync/dbiter.go create mode 100644 sync2/dbsync/dbiter_test.go delete mode 100644 sync2/dbsync/dbsync.go create mode 100644 sync2/dbsync/fptree.go rename sync2/dbsync/{dbsync_test.go => fptree_test.go} (56%) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go new file mode 100644 index 0000000000..000e5cdee2 --- /dev/null +++ b/sync2/dbsync/dbitemstore.go @@ -0,0 +1,97 @@ +package dbsync + +import ( + "context" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" +) + +type DBItemStore struct { + db sql.Database + ft *fpTree + query string + keyLen int + maxDepth int + chunkSize int +} + +var _ hashsync.ItemStore = &DBItemStore{} + +func NewDBItemStore( + np *nodePool, + db sql.Database, + query string, + keyLen, maxDepth, chunkSize int, +) *DBItemStore { + dbStore := newDBBackedStore(db, query, keyLen, maxDepth) + return &DBItemStore{ + db: db, + ft: newFPTree(np, dbStore, maxDepth), + query: query, + keyLen: keyLen, + maxDepth: maxDepth, + chunkSize: chunkSize, + } +} + +// Add implements hashsync.ItemStore. 
+func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { + return d.ft.addHash(k.(KeyBytes)) +} + +func (d *DBItemStore) iter(min, max KeyBytes) (hashsync.Iterator, error) { + return newDBRangeIterator(d.db, d.query, min, max, d.chunkSize) +} + +// GetRangeInfo implements hashsync.ItemStore. +func (d *DBItemStore) GetRangeInfo( + preceding hashsync.Iterator, + x, y hashsync.Ordered, + count int, +) (hashsync.RangeInfo, error) { + // QQQQQ: note: iter's max is inclusive!!!! + // TBD: QQQQQ: need count limiting in ft.fingerprintInterval + panic("unimplemented") +} + +// Min implements hashsync.ItemStore. +func (d *DBItemStore) Min() (hashsync.Iterator, error) { + it1 := make(KeyBytes, d.keyLen) + it2 := make(KeyBytes, d.keyLen) + for i := range it2 { + it2[i] = 0xff + } + return d.iter(it1, it2) +} + +// Copy implements hashsync.ItemStore. +func (d *DBItemStore) Copy() hashsync.ItemStore { + return &DBItemStore{ + db: d.db, + ft: d.ft.clone(), + query: d.query, + keyLen: d.keyLen, + maxDepth: d.maxDepth, + chunkSize: d.chunkSize, + } +} + +// Has implements hashsync.ItemStore. 
+func (d *DBItemStore) Has(k hashsync.Ordered) (bool, error) { + id := k.(KeyBytes) + if len(id) < d.keyLen { + panic("BUG: short key passed") + } + tailRefs := []tailRef{ + {ref: load64(id) >> (64 - d.maxDepth), limit: -1}, + } + found := false + if err := d.ft.iterateIDs(tailRefs, func(_ tailRef, cur KeyBytes) bool { + c := id.Compare(cur) + found = c == 0 + return c > 0 + }); err != nil { + return false, err + } + return found, nil +} diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go new file mode 100644 index 0000000000..359e43f1ba --- /dev/null +++ b/sync2/dbsync/dbiter.go @@ -0,0 +1,211 @@ +package dbsync + +import ( + "bytes" + "errors" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" +) + +type KeyBytes []byte + +var _ hashsync.Ordered = KeyBytes(nil) + +func (k KeyBytes) Compare(other any) int { + return bytes.Compare(k, other.(KeyBytes)) +} + +type dbRangeIterator struct { + db sql.Database + from, to KeyBytes + query string + chunkSize int + chunk []KeyBytes + pos int + keyLen int +} + +var _ hashsync.Iterator = &dbRangeIterator{} + +// makeDBIterator creates a dbRangeIterator and initializes it from the database. +// Note that [from, to] range is inclusive. 
+func newDBRangeIterator( + db sql.Database, + query string, + from, to KeyBytes, + chunkSize int, +) (hashsync.Iterator, error) { + if from == nil { + panic("BUG: makeDBIterator: nil from") + } + if to == nil { + panic("BUG: makeDBIterator: nil to") + } + if chunkSize <= 0 { + panic("BUG: makeDBIterator: chunkSize must be > 0") + } + it := &dbRangeIterator{ + db: db, + from: from, + to: to, + query: query, + chunkSize: chunkSize, + keyLen: len(from), + chunk: make([]KeyBytes, chunkSize), + } + if err := it.load(); err != nil { + return nil, err + } + return it, nil +} + +func (it *dbRangeIterator) load() error { + n := 0 + var ierr error + _, err := it.db.Exec( + it.query, func(stmt *sql.Statement) { + stmt.BindBytes(1, it.from) + stmt.BindBytes(2, it.to) + stmt.BindInt64(3, int64(it.chunkSize)) + }, + func(stmt *sql.Statement) bool { + if n >= len(it.chunk) { + ierr = errors.New("too many rows") + return false + } + // we reuse existing slices when possible for retrieving new IDs + id := it.chunk[n] + if id == nil { + id = make([]byte, it.keyLen) + it.chunk[n] = id + } + stmt.ColumnBytes(0, id) + n++ + return true + }) + if err != nil || ierr != nil { + return errors.Join(ierr, err) + } + it.pos = 0 + if n < len(it.chunk) { + // short chunk means there are no more data + it.from = nil + it.chunk = it.chunk[:n] + } else { + copy(it.from, it.chunk[n-1]) + if incID(it.from) || bytes.Compare(it.from, it.to) >= 0 { + // no more items after this full chunk + it.from = nil + } + } + return nil +} + +func (it *dbRangeIterator) Key() hashsync.Ordered { + if it.pos < len(it.chunk) { + key := make(KeyBytes, it.keyLen) + copy(key, it.chunk[it.pos]) + return key + } + return nil +} + +func (it *dbRangeIterator) Next() error { + if it.pos >= len(it.chunk) { + return nil + } + it.pos++ + if it.pos < len(it.chunk) || it.from == nil { + return nil + } + return it.load() +} + +func incID(id []byte) (overflow bool) { + for i := len(id) - 1; i >= 0; i-- { + id[i]++ + if id[i] != 0 { 
+ return false + } + } + + return true +} + +type concatIterator struct { + iters []hashsync.Iterator +} + +var _ hashsync.Iterator = &concatIterator{} + +// concatIterators concatenates multiple iterators into one. +// It assumes that the iterators follow one after another in the order of their keys. +func concatIterators(iters ...hashsync.Iterator) hashsync.Iterator { + return &concatIterator{iters: iters} +} + +func (c *concatIterator) Key() hashsync.Ordered { + if len(c.iters) == 0 { + return nil + } + return c.iters[0].Key() +} + +func (c *concatIterator) Next() error { + if len(c.iters) == 0 { + return nil + } + if err := c.iters[0].Next(); err != nil { + return err + } + for len(c.iters) > 0 { + if c.iters[0].Key() != nil { + break + } + c.iters = c.iters[1:] + } + return nil +} + +type combinedIterator struct { + iters []hashsync.Iterator + ahead hashsync.Iterator +} + +// combineIterators combines multiple iterators into one. +// Unlike concatIterator, it does not assume that the iterators follow one after another +// in the order of their keys. Instead, it always returns the smallest key among all +// iterators. 
+func combineIterators(iters ...hashsync.Iterator) hashsync.Iterator { + return &combinedIterator{iters: iters} +} + +func (c *combinedIterator) aheadIterator() hashsync.Iterator { + if c.ahead == nil { + if len(c.iters) == 0 { + return nil + } + c.ahead = c.iters[0] + for i := 1; i < len(c.iters); i++ { + if c.iters[i].Key() != nil { + if c.ahead.Key() == nil || c.iters[i].Key().Compare(c.ahead.Key()) < 0 { + c.ahead = c.iters[i] + } + } + } + } + return c.ahead +} + +func (c *combinedIterator) Key() hashsync.Ordered { + return c.aheadIterator().Key() +} + +func (c *combinedIterator) Next() error { + if err := c.aheadIterator().Next(); err != nil { + return err + } + c.ahead = nil + return nil +} diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go new file mode 100644 index 0000000000..7b37ecea13 --- /dev/null +++ b/sync2/dbsync/dbiter_test.go @@ -0,0 +1,276 @@ +package dbsync + +import ( + "encoding/hex" + "errors" + "fmt" + "testing" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/stretchr/testify/require" +) + +func TestIncID(t *testing.T) { + for _, tc := range []struct { + id, expected KeyBytes + overflow bool + }{ + { + id: KeyBytes{0x00, 0x00, 0x00, 0x00}, + expected: KeyBytes{0x00, 0x00, 0x00, 0x01}, + overflow: false, + }, + { + id: KeyBytes{0x00, 0x00, 0x00, 0xff}, + expected: KeyBytes{0x00, 0x00, 0x01, 0x00}, + overflow: false, + }, + { + id: KeyBytes{0xff, 0xff, 0xff, 0xff}, + expected: KeyBytes{0x00, 0x00, 0x00, 0x00}, + overflow: true, + }, + } { + id := make(KeyBytes, len(tc.id)) + copy(id, tc.id) + require.Equal(t, tc.overflow, incID(id)) + require.Equal(t, tc.expected, id) + } +} + +func populateDB(t *testing.T, keyLen int, content []KeyBytes) sql.Database { + db := sql.InMemory(sql.WithIgnoreSchemaDrift()) + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d))", keyLen), nil, nil) + 
require.NoError(t, err) + for _, id := range content { + _, err := db.Exec( + "insert into foo(id) values(?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, id) + }, nil) + require.NoError(t, err) + } + return db +} + +const testQuery = "select id from foo where id between ? and ? order by id limit ?" + +func TestDBRangeIterator(t *testing.T) { + db := populateDB(t, 4, []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x00, 0x00, 0x00, 0x05}, + {0x00, 0x00, 0x00, 0x07}, + {0x00, 0x00, 0x01, 0x00}, + {0x00, 0x00, 0x03, 0x00}, + {0x00, 0x01, 0x00, 0x00}, + {0x00, 0x05, 0x00, 0x00}, + {0x03, 0x05, 0x00, 0x00}, + {0x09, 0x05, 0x00, 0x00}, + {0x0a, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, + }) + for _, tc := range []struct { + from, to KeyBytes + chunkSize int + items []KeyBytes + }{ + { + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + to: KeyBytes{0x00, 0x00, 0x00, 0x00}, + chunkSize: 4, + items: nil, + }, + { + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + to: KeyBytes{0x00, 0x00, 0x00, 0x08}, + chunkSize: 4, + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x00, 0x00, 0x00, 0x05}, + {0x00, 0x00, 0x00, 0x07}, + }, + }, + { + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + to: KeyBytes{0x00, 0x00, 0x03, 0x00}, + chunkSize: 4, + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x00, 0x00, 0x00, 0x05}, + {0x00, 0x00, 0x00, 0x07}, + {0x00, 0x00, 0x01, 0x00}, + {0x00, 0x00, 0x03, 0x00}, + }, + }, + { + from: KeyBytes{0x00, 0x00, 0x03, 0x00}, + to: KeyBytes{0x09, 0x05, 0x00, 0x00}, + chunkSize: 4, + items: []KeyBytes{ + {0x00, 0x00, 0x03, 0x00}, + {0x00, 0x01, 0x00, 0x00}, + {0x00, 0x05, 0x00, 0x00}, + {0x03, 0x05, 0x00, 0x00}, + {0x09, 0x05, 0x00, 0x00}, + }, + }, + { + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + to: KeyBytes{0xff, 0xff, 0xff, 0xff}, + chunkSize: 4, + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x00, 0x00, 0x00, 0x05}, + {0x00, 0x00, 0x00, 0x07}, + 
{0x00, 0x00, 0x01, 0x00}, + {0x00, 0x00, 0x03, 0x00}, + {0x00, 0x01, 0x00, 0x00}, + {0x00, 0x05, 0x00, 0x00}, + {0x03, 0x05, 0x00, 0x00}, + {0x09, 0x05, 0x00, 0x00}, + {0x0a, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, + }, + }, + } { + it, err := newDBRangeIterator(db, testQuery, tc.from, tc.to, tc.chunkSize) + require.NoError(t, err) + if len(tc.items) == 0 { + require.Nil(t, it.Key()) + } else { + var collected []KeyBytes + for i := 0; i < len(tc.items); i++ { + if k := it.Key(); k != nil { + collected = append(collected, k.(KeyBytes)) + } else { + break + } + require.NoError(t, it.Next()) + } + require.Nil(t, it.Key()) + require.Equal(t, tc.items, collected, "from=%s to=%s chunkSize=%d", + hex.EncodeToString(tc.from), hex.EncodeToString(tc.to), tc.chunkSize) + } + } +} + +type fakeIterator struct { + items []KeyBytes +} + +var _ hashsync.Iterator = &fakeIterator{} + +func (it *fakeIterator) Key() hashsync.Ordered { + if len(it.items) == 0 { + return nil + } + return KeyBytes(it.items[0]) +} + +func (it *fakeIterator) Next() error { + if len(it.items) != 0 { + it.items = it.items[1:] + } + if len(it.items) != 0 && string(it.items[0]) == "error" { + return errors.New("iterator error") + } + return nil +} + +func TestConcatIterators(t *testing.T) { + it1 := &fakeIterator{ + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + }, + } + it2 := &fakeIterator{ + items: []KeyBytes{ + {0x0a, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, + }, + } + + it := concatIterators(it1, it2) + var collected []KeyBytes + for i := 0; i < 4; i++ { + collected = append(collected, it.Key().(KeyBytes)) + require.NoError(t, it.Next()) + } + require.Nil(t, it.Key()) + require.Equal(t, []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x0a, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, + }, collected) + + it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} + it2 = &fakeIterator{items: nil} + + it = 
concatIterators(it1, it2) + require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.Error(t, it.Next()) + + it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}}} + it2 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} + + it = concatIterators(it1, it2) + require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.NoError(t, it.Next()) + require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) + require.Error(t, it.Next()) +} + +func TestCombineIterators(t *testing.T) { + it1 := &fakeIterator{ + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x0a, 0x05, 0x00, 0x00}, + }, + } + it2 := &fakeIterator{ + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x03}, + {0xff, 0xff, 0xff, 0xff}, + }, + } + + it := combineIterators(it1, it2) + var collected []KeyBytes + for i := 0; i < 4; i++ { + collected = append(collected, it.Key().(KeyBytes)) + require.NoError(t, it.Next()) + } + require.Nil(t, it.Key()) + require.Equal(t, []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x0a, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, + }, collected) + + it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} + it2 = &fakeIterator{items: nil} + + it = combineIterators(it1, it2) + require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.Error(t, it.Next()) + + it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}}} + it2 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} + + it = combineIterators(it1, it2) + require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.NoError(t, it.Next()) + require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) + require.Error(t, it.Next()) +} diff --git a/sync2/dbsync/dbsync.go b/sync2/dbsync/dbsync.go deleted file mode 100644 index 04954957bf..0000000000 --- a/sync2/dbsync/dbsync.go +++ /dev/null @@ -1,771 +0,0 @@ -package dbsync - -import ( - "bytes" - "encoding/binary" - "encoding/hex" - "fmt" - "io" - "math/bits" - "slices" - "strconv" - - 
"github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" -) - -const ( - fingerprintBytes = 12 - // cachedBits = 24 - // cachedSize = 1 << cachedBits - // cacheMask = cachedSize - 1 - maxIDBytes = 32 - bit63 = 1 << 63 -) - -type fingerprint [fingerprintBytes]byte - -func (fp fingerprint) Compare(other fingerprint) int { - return bytes.Compare(fp[:], other[:]) -} - -func (fp fingerprint) String() string { - return hex.EncodeToString(fp[:]) -} - -func (fp *fingerprint) update(h []byte) { - for n := range *fp { - (*fp)[n] ^= h[n] - } -} - -func (fp *fingerprint) bitFromLeft(n int) bool { - if n > fingerprintBytes*8 { - panic("BUG: bad fingerprint bit index") - } - return (fp[n>>3]>>(7-n&0x7))&1 != 0 -} - -func hexToFingerprint(s string) fingerprint { - b, err := hex.DecodeString(s) - if err != nil { - panic("bad hex fingerprint: " + err.Error()) - } - var fp fingerprint - if len(b) != len(fp) { - panic("bad hex fingerprint") - } - copy(fp[:], b) - return fp -} - -// const ( -// nodeFlagLeaf = 1 << 31 -// nodeFlagMask = nodeFlagLeaf -// ) - -// NOTE: all leafs are on the last level - -// type nodeIndex uint32 - -// const noIndex nodeIndex = ^nodeIndex(0) - -// // TODO: nodePool limiting -// type nodePool struct { -// mtx sync.Mutex -// nodes []node -// // freeList is 1-based so that nodePool doesn't need a constructor -// freeList nodeIndex -// } - -// func (np *nodePool) node(idx nodeIndex) node { -// np.mtx.Lock() -// defer np.mtx.Unlock() -// return np.nodeUnlocked(idx) -// } - -// func (np *nodePool) nodeUnlocked(idx nodeIndex) node { -// node := &np.nodes[idx] -// refs := node.refCount -// if refs < 0 { -// panic("BUG: negative nodePool entry refcount") -// } else if refs == 0 { -// panic("BUG: referencing a free nodePool entry") -// } -// return *node -// } - -// func (np *nodePool) add(fp fingerprint, c uint32, left, right nodeIndex) nodeIndex { -// np.mtx.Lock() -// defer np.mtx.Unlock() -// var idx nodeIndex -// // 
validate indices -// if left != noIndex { -// np.nodeUnlocked(left) -// } -// if right != noIndex { -// np.nodeUnlocked(right) -// } -// if np.freeList != 0 { -// idx = nodeIndex(np.freeList - 1) -// np.freeList = np.nodes[idx].left -// np.nodes[idx].refCount++ -// if np.nodes[idx].refCount != 1 { -// panic("BUG: refCount != 1 for a node taken from the freelist") -// } -// } else { -// idx = nodeIndex(len(np.nodes)) -// np.nodes = append(np.nodes, node{refCount: 1}) -// } -// node := &np.nodes[idx] -// node.fp = fp -// node.c = c -// node.left = left -// node.right = right -// return idx -// } - -// func (np *nodePool) release(idx nodeIndex) { -// np.mtx.Lock() -// defer np.mtx.Unlock() -// node := &np.nodes[idx] -// if node.refCount <= 0 { -// panic("BUG: negative nodePool entry refcount") -// } -// node.refCount-- -// if node.refCount == 0 { -// node.left = np.freeList -// np.freeList = idx + 1 -// } -// } - -// func (np *nodePool) ref(idx nodeIndex) { -// np.mtx.Lock() -// np.nodes[idx].refCount++ -// np.mtx.Unlock() -// } - -type nodeIndex uint32 - -const noIndex = ^nodeIndex(0) - -type nodePool struct { - rcPool[node, nodeIndex] -} - -func (np *nodePool) add(fp fingerprint, c uint32, left, right nodeIndex) nodeIndex { - return np.rcPool.add(node{fp: fp, c: c, left: left, right: right}) -} - -func (np *nodePool) node(idx nodeIndex) node { - return np.rcPool.item(idx) -} - -// fpTree node. 
-// The nodes are immutable except for refCount field, which should -// only be used directly by nodePool methods -type node struct { - fp fingerprint - c uint32 - left, right nodeIndex -} - -func (n node) leaf() bool { - return n.left == noIndex && n.right == noIndex -} - -// type node struct { -// // 16-byte structure with alignment -// // The cache is 512 MiB per 1<<24 (16777216) IDs -// fp fingerprint -// c uint32 -// } - -// func (node *node) empty() bool { -// return node.c == 0 -// } - -// func (node *node) leaf() bool { -// return node.c&nodeFlagLeaf != 0 -// } - -// func (node *node) count() uint32 { -// if node.leaf() { -// return 1 -// } -// return node.c -// } - -const ( - prefixLenBits = 6 - prefixLenMask = 1<> prefixLenBits) -} - -func (p prefix) left() prefix { - l := uint64(p) & prefixLenMask - if l == maxPrefixLen { - panic("BUG: max prefix len reached") - } - return prefix((uint64(p)&prefixBitMask)<<1 + l + 1) -} - -func (p prefix) right() prefix { - return p.left() + (1 << prefixLenBits) -} - -func (p prefix) dir(bit bool) prefix { - if bit { - return p.right() - } - return p.left() -} - -func (p prefix) String() string { - if p.len() == 0 { - return "<0>" - } - b := fmt.Sprintf("%064b", p.bits()) - return fmt.Sprintf("<%d:%s>", p.len(), b[64-p.len():]) -} - -func (p prefix) highBit() bool { - if p == 0 { - return false - } - return p.bits()>>(p.len()-1) != 0 -} - -func (p prefix) minID(b []byte) { - if len(b) < 8 { - panic("BUG: id slice too small") - } - v := p.bits() << (64 - p.len()) - binary.BigEndian.PutUint64(b, v) - for n := 8; n < len(b); n++ { - b[n] = 0 - } -} - -func (p prefix) maxID(b []byte) { - if len(b) < 8 { - panic("BUG: id slice too small") - } - s := uint64(64 - p.len()) - v := (p.bits() << s) | ((1 << s) - 1) - binary.BigEndian.PutUint64(b, v) - for n := 8; n < len(b); n++ { - b[n] = 0xff - } -} - -// shift removes the highest bit from the prefix -func (p prefix) shift() prefix { - switch l := p.len(); l { - case 0: - 
panic("BUG: can't shift zero prefix") - case 1: - return 0 - default: - l-- - return mkprefix(p.bits()&((1< maxPrefixLen { -// panic("BUG: bad prefix length") -// } -// if nbits == 0 { -// return 0 -// } -// v := load64(h) -// return prefix((v>>(64-nbits-prefixLenBits))&prefixBitMask + uint64(nbits)) -// } - -func preFirst0(h []byte) prefix { - l := min(maxPrefixLen, bits.LeadingZeros64(^load64(h))) - return mkprefix((1<>(64-l), l) -} - -type fpResult struct { - fp fingerprint - count uint32 -} - -type aggResult struct { - tailRefs []uint64 - fp fingerprint - count uint32 - itype int -} - -func (r *aggResult) update(node node) { - r.fp.update(node.fp[:]) - r.count += node.c - // fmt.Fprintf(os.Stderr, "QQQQQ: r.count <= %d r.fp <= %s\n", r.count, r.fp) -} - -type idStore interface { - registerHash(h []byte, maxDepth int) error - iterateIDs(tailRefs []uint64, maxDepth int, toCall func(id []byte)) error -} - -type fpTree struct { - np *nodePool - idStore idStore - root nodeIndex - maxDepth int -} - -func newFPTree(np *nodePool, idStore idStore, maxDepth int) *fpTree { - return &fpTree{np: np, idStore: idStore, root: noIndex, maxDepth: maxDepth} -} - -func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: fpA %s fpB %s p %s\n", fpA.String(), fpB.String(), p) - fpCombined := fpA - fpCombined.update(fpB[:]) - if ft.maxDepth != 0 && p.len() == ft.maxDepth { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add at maxDepth\n") - return ft.np.add(fpCombined, curCount+1, noIndex, noIndex) - } - if curCount != 1 { - panic("BUG: pushDown of non-1-leaf below maxDepth") - } - dirA := fpA.bitFromLeft(p.len()) - dirB := fpB.bitFromLeft(p.len()) - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: bitFromLeft %d: dirA %v dirB %v\n", p.len(), dirA, dirB) - if dirA == dirB { - childIdx := ft.pushDown(fpA, fpB, p.dir(dirA), 1) - if dirA { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: left\n") - return 
ft.np.add(fpCombined, 2, noIndex, childIdx) - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: right\n") - return ft.np.add(fpCombined, 2, childIdx, noIndex) - } - } - - idxA := ft.np.add(fpA, 1, noIndex, noIndex) - idxB := ft.np.add(fpB, curCount, noIndex, noIndex) - if dirA { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add A-B\n") - return ft.np.add(fpCombined, 2, idxB, idxA) - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add B-A\n") - return ft.np.add(fpCombined, 2, idxA, idxB) - } -} - -func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { - if idx == noIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: addNew fp %s p %s idx %d\n", fp.String(), p.String(), idx) - return ft.np.add(fp, 1, noIndex, noIndex) - } - node := ft.np.node(idx) - // We've got a copy of the node, so we release it right away. - // This way, it'll likely be reused for the new nodes created - // as this hash is being added, as the node pool's freeList is - // LIFO - ft.np.release(idx) - if node.c == 1 || (ft.maxDepth != 0 && p.len() == ft.maxDepth) { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: pushDown fp %s p %s idx %d\n", fp.String(), p.String(), idx) - // we're at a leaf node, need to push down the old fingerprint, or, - // if we've reached the max depth, just update the current node - return ft.pushDown(fp, node.fp, p, node.c) - } - fpCombined := fp - fpCombined.update(node.fp[:]) - if fp.bitFromLeft(p.len()) { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s idx %d\n", fp.String(), p.String(), idx) - newRight := ft.addValue(fp, p.right(), node.right) - return ft.np.add(fpCombined, node.c+1, node.left, newRight) - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s idx %d\n", fp.String(), p.String(), idx) - newLeft := ft.addValue(fp, p.left(), node.left) - return ft.np.add(fpCombined, node.c+1, newLeft, node.right) - } -} - -func (ft *fpTree) addHash(h []byte) error { - var fp 
fingerprint - fp.update(h) - ft.root = ft.addValue(fp, 0, ft.root) - return ft.idStore.registerHash(h, ft.maxDepth) - // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: new root %d\n", ft.root) -} - -func (ft *fpTree) followPrefix(from nodeIndex, p prefix) (nodeIndex, bool) { - // fmt.Fprintf(os.Stderr, "QQQQQ: followPrefix: from %d p %s highBit %v\n", from, p, p.highBit()) - switch { - case p == 0: - return from, true - case from == noIndex: - return noIndex, false - case ft.np.node(from).leaf(): - return from, false - case p.highBit(): - return ft.followPrefix(ft.np.node(from).right, p.shift()) - default: - return ft.followPrefix(ft.np.node(from).left, p.shift()) - } -} - -func (ft *fpTree) tailRefFromPrefix(p prefix) uint64 { - // TODO: QQQQ: FIXME: this may happen with reverse intervals, - // but should we even be checking the prefixes in this case? - // if p.len() != ft.maxDepth { - // panic("BUG: tail from short prefix") - // } - return p.bits() -} - -func (ft *fpTree) tailRefFromFingerprint(fp fingerprint) uint64 { - v := load64(fp[:]) - if ft.maxDepth >= 64 { - return v - } - // fmt.Fprintf(os.Stderr, "QQQQQ: AAAAA: v %016x maxDepth %d shift %d\n", v, ft.maxDepth, (64 - ft.maxDepth)) - return v >> (64 - ft.maxDepth) -} - -func (ft *fpTree) tailRefFromNodeAndPrefix(n node, p prefix) uint64 { - if n.c == 1 { - return ft.tailRefFromFingerprint(n.fp) - } else { - return ft.tailRefFromPrefix(p) - } -} - -func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) { - if idx == noIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: noIndex\n", idx, v, p) - return - } - node := ft.np.node(idx) - switch { - case p.len() == ft.maxDepth: - if node.left != noIndex || node.right != noIndex { - panic("BUG: node @ maxDepth has children") - } - tail := ft.tailRefFromPrefix(p) - r.tailRefs = append(r.tailRefs, tail) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: hit maxDepth, add prefix to the tails: %016x\n", idx, v, p, tail) 
- case node.leaf(): - // For leaf 1-nodes, we can use the fingerprint to get tailRef - // by which the actual IDs will be selected - if node.c != 1 { - panic("BUG: leaf non-1 node below maxDepth") - } - tail := ft.tailRefFromFingerprint(node.fp) - r.tailRefs = append(r.tailRefs, tail) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: hit 1-leaf, add prefix to the tails: %016x (fp %s)\n", idx, v, p, tail, node.fp) - case v&bit63 == 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: incl right node %d + go left to node %d\n", idx, v, p, node.right, node.left) - if node.right != noIndex { - r.update(ft.np.node(node.right)) - } - ft.aggregateLeft(node.left, v<<1, p.left(), r) - default: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft %d %016x %s: go right node %d\n", idx, v, p, node.right) - ft.aggregateLeft(node.right, v<<1, p.right(), r) - } -} - -func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r *aggResult) { - if idx == noIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: noIndex\n", idx, v, p) - return - } - node := ft.np.node(idx) - switch { - case p.len() == ft.maxDepth: - if node.left != noIndex || node.right != noIndex { - panic("BUG: node @ maxDepth has children") - } - tail := ft.tailRefFromPrefix(p) - r.tailRefs = append(r.tailRefs, tail) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: hit maxDepth, add prefix to the tails: %016x\n", idx, v, p, tail) - case node.leaf(): - // For leaf 1-nodes, we can use the fingerprint to get tailRef - // by which the actual IDs will be selected - if node.c != 1 { - panic("BUG: leaf non-1 node below maxDepth") - } - tail := ft.tailRefFromFingerprint(node.fp) - r.tailRefs = append(r.tailRefs, tail) - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: hit 1-leaf, add prefix to the tails: %016x (fp %s)\n", idx, v, p, tail, node.fp) - case v&bit63 == 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: go left to node 
%d\n", idx, v, p, node.left) - ft.aggregateRight(node.left, v<<1, p.left(), r) - default: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateRight %d %016x %s: incl left node %d + go right to node %d\n", idx, v, p, node.left, node.right) - if node.left != noIndex { - r.update(ft.np.node(node.left)) - } - ft.aggregateRight(node.right, v<<1, p.right(), r) - } -} - -func (ft *fpTree) aggregateInterval(x, y []byte) aggResult { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateInterval: %s %s\n", hex.EncodeToString(x), hex.EncodeToString(y)) - var r aggResult - r.itype = bytes.Compare(x, y) - switch { - case r.itype == 0: - // the whole set - if ft.root != noIndex { - r.update(ft.np.node(ft.root)) - } - case r.itype < 0: - // "proper" interval: [x; lca); (lca; y) - p := commonPrefix(x, y) - lca, found := ft.followPrefix(ft.root, p) - // fmt.Fprintf(os.Stderr, "QQQQQ: commonPrefix %s lca %d found %v\n", p, lca, found) - switch { - case found: - lcaNode := ft.np.node(lca) - ft.aggregateLeft(lcaNode.left, load64(x)<<(p.len()+1), p.left(), &r) - ft.aggregateRight(lcaNode.right, load64(y)<<(p.len()+1), p.right(), &r) - case lca != noIndex: - // fmt.Fprintf(os.Stderr, "QQQQQ: commonPrefix %s NOT found but have lca %d\n", p, lca) - // Didn't reach LCA in the tree b/c ended up - // at a leaf, just use the prefix to go - // through the IDs - lcaNode := ft.np.node(lca) - r.tailRefs = append(r.tailRefs, ft.tailRefFromNodeAndPrefix(lcaNode, p)) - } - default: - // inverse interval: [min; y); [x; max] - pf1 := preFirst1(y) - idx1, found := ft.followPrefix(ft.root, pf1) - // fmt.Fprintf(os.Stderr, "QQQQQ: pf1 %s idx1 %d found %v\n", pf1, idx1, found) - switch { - case found: - ft.aggregateRight(idx1, load64(y)<> (64 - maxDepth) - s := m.ids[idx] - n := slices.IndexFunc(s, func(cur []byte) bool { - return bytes.Compare(cur, h) > 0 - }) - if n < 0 { - m.ids[idx] = append(s, h) - } else { - m.ids[idx] = slices.Insert(s, n, h) - } - return nil -} - -func (m *memIDStore) iterateIDs(tailRefs 
[]uint64, maxDepth int, toCall func(id []byte)) error { - for _, t := range tailRefs { - ids := m.ids[t] - for _, id := range ids { - toCall(id) - } - } - return nil -} - -type sqlIDStore struct { - db sql.StateDatabase -} - -func newSQLIDStore(db sql.StateDatabase) *sqlIDStore { - return &sqlIDStore{db: db} -} - -func (s *sqlIDStore) registerHash(h []byte, maxDepth int) error { - // should be registered by the handler code - return nil -} - -func (s *sqlIDStore) iterateIDs(tailRefs []uint64, maxDepth int, toCall func(id []byte)) error { - for _, t := range tailRefs { - p := mkprefix(t, maxDepth) - var minID, maxID types.Hash32 - p.minID(minID[:]) - p.maxID(maxID[:]) - // start := time.Now() - if _, err := s.db.Exec( - "select id from atxs where id between ? and ?", - func(stmt *sql.Statement) { - stmt.BindBytes(1, minID[:]) - stmt.BindBytes(2, maxID[:]) - }, - func(stmt *sql.Statement) bool { - var id types.Hash32 - stmt.ColumnBytes(0, id[:]) - toCall(id[:]) - return true - }, - ); err != nil { - return err - } - // fmt.Fprintf(os.Stderr, "QQQQQ: %v: sel atxs between %s and %s\n", time.Now().Sub(start), minID.String(), maxID.String()) - } - return nil -} - -// type inMemFPTree struct { -// tree *fpTree -// ids [][][]byte -// } - -// func newInMemFPTree(np *nodePool, maxDepth int) *inMemFPTree { -// if maxDepth == 0 { -// panic("BUG: can't use newInMemFPTree with zero maxDepth") -// } -// return &inMemFPTree{ -// tree: newFPTree(np, maxDepth), -// ids: make([][][]byte, 1<> (64 - mft.tree.maxDepth) -// s := mft.ids[idx] -// n := slices.IndexFunc(s, func(cur []byte) bool { -// return bytes.Compare(cur, h) > 0 -// }) -// if n < 0 { -// mft.ids[idx] = append(s, h) -// } else { -// mft.ids[idx] = slices.Insert(s, n, h) -// } -// } - -// func (mft *inMemFPTree) aggregateInterval(x, y []byte) fpResult { -// r := mft.tree.aggregateInterval(x, y) -// for _, t := range r.tailRefs { -// ids := mft.ids[t] -// for _, id := range ids { -// // FIXME: this can be optimized as the 
IDs are ordered -// if idWithinInterval(id, x, y, r.itype) { -// // fmt.Fprintf(os.Stderr, "QQQQQ: including tail: %s\n", hex.EncodeToString(id)) -// r.fp.update(id) -// r.count++ -// } else { -// // fmt.Fprintf(os.Stderr, "QQQQQ: NOT including tail: %s\n", hex.EncodeToString(id)) -// } -// } -// } -// return fpResult{fp: r.fp, count: r.count} -// } - -func idWithinInterval(id, x, y []byte, itype int) bool { - switch itype { - case 0: - return true - case -1: - return bytes.Compare(id, x) >= 0 && bytes.Compare(id, y) < 0 - default: - return bytes.Compare(id, y) < 0 || bytes.Compare(id, x) >= 0 - } -} - -// TBD: perhaps use json-based SELECTs -// TBD: extra cache for after-24bit entries -// TBD: benchmark 24-bit limit (not going beyond the cache) -// TBD: optimize, get rid of binary.BigEndian.* diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go new file mode 100644 index 0000000000..f5404639c6 --- /dev/null +++ b/sync2/dbsync/fptree.go @@ -0,0 +1,1101 @@ +package dbsync + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "math/bits" + "os" + "runtime" + "slices" + "strconv" + "strings" + "sync" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +type trace struct { + traceEnabled bool + traceStack []string +} + +func (t *trace) out(msg string) { + fmt.Fprintf(os.Stderr, "TRACE: %s%s\n", strings.Repeat(" ", len(t.traceStack)), msg) +} + +func (t *trace) enter(format string, args ...any) { + if !t.traceEnabled { + return + } + msg := fmt.Sprintf(format, args...) 
+ t.out("ENTER: " + msg) + t.traceStack = append(t.traceStack, msg) +} + +func (t *trace) leave(results ...any) { + if !t.traceEnabled { + return + } + if len(t.traceStack) == 0 { + panic("BUG: trace stack underflow") + } + msg := t.traceStack[len(t.traceStack)-1] + if len(results) != 0 { + var r []string + for _, res := range results { + r = append(r, fmt.Sprint(res)) + } + msg += " => " + strings.Join(r, ", ") + } + t.traceStack = t.traceStack[:len(t.traceStack)-1] + t.out("LEAVE: " + msg) +} + +func (t *trace) log(format string, args ...any) { + if t.traceEnabled { + msg := fmt.Sprintf(format, args...) + t.out(msg) + } +} + +const ( + fingerprintBytes = 12 + // cachedBits = 24 + // cachedSize = 1 << cachedBits + // cacheMask = cachedSize - 1 + maxIDBytes = 32 + bit63 = 1 << 63 +) + +type fingerprint [fingerprintBytes]byte + +func (fp fingerprint) Compare(other fingerprint) int { + return bytes.Compare(fp[:], other[:]) +} + +func (fp fingerprint) String() string { + return hex.EncodeToString(fp[:]) +} + +func (fp *fingerprint) update(h []byte) { + for n := range *fp { + (*fp)[n] ^= h[n] + } +} + +func (fp *fingerprint) bitFromLeft(n int) bool { + if n > fingerprintBytes*8 { + panic("BUG: bad fingerprint bit index") + } + return (fp[n>>3]>>(7-n&0x7))&1 != 0 +} + +func hexToFingerprint(s string) fingerprint { + b, err := hex.DecodeString(s) + if err != nil { + panic("bad hex fingerprint: " + err.Error()) + } + var fp fingerprint + if len(b) != len(fp) { + panic("bad hex fingerprint") + } + copy(fp[:], b) + return fp +} + +type nodeIndex uint32 + +const noIndex = ^nodeIndex(0) + +type nodePool struct { + rcPool[node, nodeIndex] +} + +func (np *nodePool) add(fp fingerprint, c uint32, left, right nodeIndex) nodeIndex { + // panic("TBD: this is invalid, adds unneeded refs") + // if left != noIndex { + // np.rcPool.ref(left) + // } + // if right != noIndex { + // np.rcPool.ref(right) + // } + idx := np.rcPool.add(node{fp: fp, c: c, left: left, right: right}) + // 
fmt.Fprintf(os.Stderr, "QQQQQ: add: idx %d fp %s c %d left %d right %d\n", idx, fp, c, left, right) + return idx +} + +func (np *nodePool) ref(idx nodeIndex) { // TBD: QQQQ: rmme + // fmt.Fprintf(os.Stderr, "QQQQQ: ref: idx %d\n", idx) + np.rcPool.ref(idx) +} + +func (np *nodePool) release(idx nodeIndex) bool { // TBD: QQQQ: rmme + r := np.rcPool.release(idx) + // fmt.Fprintf(os.Stderr, "QQQQQ: release: idx %d: %v\n", idx, r) + return r +} + +func (np *nodePool) node(idx nodeIndex) node { + return np.rcPool.item(idx) +} + +// fpTree node. +// The nodes are immutable except for refCount field, which should +// only be used directly by nodePool methods +type node struct { + fp fingerprint + c uint32 + left, right nodeIndex +} + +func (n node) leaf() bool { + return n.left == noIndex && n.right == noIndex +} + +const ( + prefixLenBits = 6 + prefixLenMask = 1<> prefixLenBits) +} + +func (p prefix) left() prefix { + l := uint64(p) & prefixLenMask + if l == maxPrefixLen { + panic("BUG: max prefix len reached") + } + return prefix((uint64(p)&prefixBitMask)<<1 + l + 1) +} + +func (p prefix) right() prefix { + return p.left() + (1 << prefixLenBits) +} + +func (p prefix) dir(bit bool) prefix { + if bit { + return p.right() + } + return p.left() +} + +func (p prefix) String() string { + if p.len() == 0 { + return "<0>" + } + b := fmt.Sprintf("%064b", p.bits()) + return fmt.Sprintf("<%d:%s>", p.len(), b[64-p.len():]) +} + +func (p prefix) highBit() bool { + if p == 0 { + return false + } + return p.bits()>>(p.len()-1) != 0 +} + +func (p prefix) lowBit() bool { + return p&(1< maxPrefixLen { +// panic("BUG: bad prefix length") +// } +// if nbits == 0 { +// return 0 +// } +// v := load64(h) +// return prefix((v>>(64-nbits-prefixLenBits))&prefixBitMask + uint64(nbits)) +// } + +func preFirst0(h KeyBytes) prefix { + l := min(maxPrefixLen, bits.LeadingZeros64(^load64(h))) + return mkprefix((1<>(64-l), l) +} + +type fpResult struct { + fp fingerprint + count uint32 + itype int +} + 
+type tailRef struct { + // node from which this tailRef has been derived + idx nodeIndex + // maxDepth bits of the key + ref uint64 + // max count to get from this tail ref, -1 for unlimited + limit int +} + +type aggResult struct { + tailRefs []tailRef + fp fingerprint + count uint32 + itype int + limit int + lastVisited nodeIndex + lastPrefix prefix +} + +func (r *aggResult) takeAtMost(count int) int { + switch { + case r.limit < 0: + return -1 + case count <= r.limit: + r.limit -= count + default: + count = r.limit + r.limit = 0 + } + return count +} + +func (r *aggResult) update(node node) { + r.fp.update(node.fp[:]) + r.count += node.c + // // fmt.Fprintf(os.Stderr, "QQQQQ: r.count <= %d r.fp <= %s\n", r.count, r.fp) +} + +type idStore interface { + clone() idStore + registerHash(h KeyBytes) error + iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error +} + +type fpTree struct { + trace // rmme + idStore + np *nodePool + rootMtx sync.Mutex + root nodeIndex + maxDepth int +} + +func newFPTree(np *nodePool, idStore idStore, maxDepth int) *fpTree { + ft := &fpTree{np: np, idStore: idStore, root: noIndex, maxDepth: maxDepth} + runtime.SetFinalizer(ft, (*fpTree).release) + return ft +} + +func (ft *fpTree) releaseNode(idx nodeIndex) { + if idx == noIndex { + return + } + node := ft.np.node(idx) + if ft.np.release(idx) { + // fmt.Fprintf(os.Stderr, "QQQQQ: releaseNode: freed %d, release l %d r %d\n", idx, node.left, node.right) + ft.releaseNode(node.left) + ft.releaseNode(node.right) + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: releaseNode: keep %d\n", idx) + } +} + +func (ft *fpTree) release() { + ft.rootMtx.Lock() + defer ft.rootMtx.Unlock() + ft.releaseNode(ft.root) + ft.root = noIndex +} + +func (ft *fpTree) clone() *fpTree { + ft.rootMtx.Lock() + defer ft.rootMtx.Unlock() + if ft.root != noIndex { + ft.np.ref(ft.root) + } + return &fpTree{ + np: ft.np, + idStore: ft.idStore.clone(), + root: ft.root, + maxDepth: ft.maxDepth, + } +} + +func 
(ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: fpA %s fpB %s p %s\n", fpA.String(), fpB.String(), p) + fpCombined := fpA + fpCombined.update(fpB[:]) + if ft.maxDepth != 0 && p.len() == ft.maxDepth { + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add at maxDepth\n") + return ft.np.add(fpCombined, curCount+1, noIndex, noIndex) + } + if curCount != 1 { + panic("BUG: pushDown of non-1-leaf below maxDepth") + } + dirA := fpA.bitFromLeft(p.len()) + dirB := fpB.bitFromLeft(p.len()) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: bitFromLeft %d: dirA %v dirB %v\n", p.len(), dirA, dirB) + if dirA == dirB { + childIdx := ft.pushDown(fpA, fpB, p.dir(dirA), 1) + if dirA { + r := ft.np.add(fpCombined, 2, noIndex, childIdx) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: left => %d\n", r) + return r + } else { + r := ft.np.add(fpCombined, 2, childIdx, noIndex) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: right => %d\n", r) + return r + } + } + + idxA := ft.np.add(fpA, 1, noIndex, noIndex) + idxB := ft.np.add(fpB, curCount, noIndex, noIndex) + if dirA { + r := ft.np.add(fpCombined, 2, idxB, idxA) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add A-B => %d\n", r) + return r + } else { + r := ft.np.add(fpCombined, 2, idxA, idxB) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add B-A => %d\n", r) + return r + } +} + +func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { + if idx == noIndex { + r := ft.np.add(fp, 1, noIndex, noIndex) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: addNew fp %s p %s => %d\n", fp.String(), p.String(), r) + return r + } + node := ft.np.node(idx) + // defer ft.releaseNode(idx) + if node.c == 1 || (ft.maxDepth != 0 && p.len() == ft.maxDepth) { + // we're at a leaf node, need to push down the old fingerprint, or, + // if we've reached the max depth, just update the current node + r := ft.pushDown(fp, node.fp, p, node.c) + // 
fmt.Fprintf(os.Stderr, "QQQQQ: addValue: pushDown fp %s p %s oldIdx %d => %d\n", fp.String(), p.String(), idx, r) + return r + } + fpCombined := fp + fpCombined.update(node.fp[:]) + if fp.bitFromLeft(p.len()) { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d\n", fp.String(), p.String(), idx) + if node.left != noIndex { + ft.np.ref(node.left) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: ref left %d -- refCount %d\n", node.left, ft.np.entry(node.left).refCount) + } + newRight := ft.addValue(fp, p.right(), node.right) + r := ft.np.add(fpCombined, node.c+1, node.left, newRight) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d => %d node.left %d newRight %d\n", fp.String(), p.String(), idx, r, node.left, newRight) + return r + } else { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d\n", fp.String(), p.String(), idx) + if node.right != noIndex { + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: ref right %d -- refCount %d\n", node.right, ft.np.entry(node.right).refCount) + ft.np.ref(node.right) + } + newLeft := ft.addValue(fp, p.left(), node.left) + r := ft.np.add(fpCombined, node.c+1, newLeft, node.right) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d => %d newLeft %d node.right %d\n", fp.String(), p.String(), idx, r, newLeft, node.right) + return r + } +} + +func (ft *fpTree) addHash(h KeyBytes) error { + // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: %s\n", hex.EncodeToString(h)) + var fp fingerprint + fp.update(h) + ft.rootMtx.Lock() + defer ft.rootMtx.Unlock() + oldRoot := ft.root + ft.root = ft.addValue(fp, 0, ft.root) + ft.releaseNode(oldRoot) + // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: new root %d\n", ft.root) + return ft.idStore.registerHash(h) +} + +func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeIndex, rp prefix, found bool) { + ft.enter("followPrefix: from %d p %s highBit %v", from, p, p.highBit()) + defer func() 
{ ft.leave(idx, rp, found) }() + + switch { + case from == noIndex: + return noIndex, followed, false + case p == 0: + return from, followed, true + case ft.np.node(from).leaf(): + return from, followed, false + case p.highBit(): + return ft.followPrefix(ft.np.node(from).right, p.shift(), followed.right()) + default: + return ft.followPrefix(ft.np.node(from).left, p.shift(), followed.left()) + } +} + +func (ft *fpTree) tailRefFromPrefix(idx nodeIndex, p prefix, limit int) tailRef { + // TODO: QQQQ: FIXME: this may happen with reverse intervals, + // but should we even be checking the prefixes in this case? + if p.len() != ft.maxDepth { + panic("BUG: tail from short prefix") + } + return tailRef{idx: idx, ref: p.bits(), limit: limit} +} + +func (ft *fpTree) tailRefFromFingerprint(idx nodeIndex, fp fingerprint, limit int) tailRef { + v := load64(fp[:]) + if ft.maxDepth >= 64 { + return tailRef{idx: idx, ref: v, limit: limit} + } + // // fmt.Fprintf(os.Stderr, "QQQQQ: AAAAA: v %016x maxDepth %d shift %d\n", v, ft.maxDepth, (64 - ft.maxDepth)) + return tailRef{idx: idx, ref: v >> (64 - ft.maxDepth), limit: limit} +} + +func (ft *fpTree) tailRefFromNodeAndPrefix(idx nodeIndex, n node, p prefix, limit int) tailRef { + if n.c == 1 { + return ft.tailRefFromFingerprint(idx, n.fp, limit) + } else { + return ft.tailRefFromPrefix(idx, p, limit) + } +} + +func (ft *fpTree) descendToLeftmostLeaf(idx nodeIndex, p prefix) (nodeIndex, prefix) { + switch { + case idx == noIndex: + return noIndex, p + case ft.np.node(idx).leaf(): + return idx, p + default: + return ft.descendToLeftmostLeaf(ft.np.node(idx).left, p.left()) + } +} + +// func (ft *fpTree) descendToNextLeaf(idx nodeIndex, p, rem prefix) (nodeIndex, prefix) { +// switch { +// case idx == noIndex: +// panic("BUG: descendToNextLeaf: no node") +// case rem == 0: +// return noIndex, p +// case rem.highBit(): +// // Descending to the right branch by following p: +// // the next leaf, if there's any, is further down the right 
branch. +// newIdx, newP := ft.descendToNextLeaf(ft.np.node(idx).right, p.right(), rem.shift()) +// return newIdx, newP +// default: +// // Descending to the left branch by following p: +// // if the leaf is not found in the left branch, it's the leftmost leaf +// // on the right branch +// newIdx, newP := ft.descendToNextLeaf(ft.np.node(idx).left, p.left(), rem.shift()) +// if newIdx != noIndex { +// return newIdx, newP +// } +// return ft.descendToLeftmostLeaf(ft.np.node(idx).right, p.right()) +// } +// } + +// func (ft *fpTree) nextLeaf(p prefix) (nodeIndex, prefix) { +// if ft.root == noIndex { +// return noIndex, 0 +// } +// return ft.descendToNextLeaf(ft.root, 0, p) +// } + +func (ft *fpTree) visitNode(idx nodeIndex, p prefix, r *aggResult) (node, bool) { + if idx == noIndex { + return node{}, false + } + ft.log("visitNode: idx %d p %s", idx, p) + r.lastVisited = idx + r.lastPrefix = p + return ft.np.node(idx), true +} + +func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, r *aggResult) { + ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, r.limit, + r.fp.String(), r.count) + defer func() { + ft.leave(r.fp, r.count) + }() + node, ok := ft.visitNode(idx, p, r) + switch { + case !ok || r.limit == 0: + // for r.limit == 0, it's important that we still visit the node + // so that we can get the item immediately following the included items + ft.log("stop: ok %v r.limit %d", ok, r.limit) + case r.limit < 0: + // no limit + ft.log("no limit") + r.update(node) + case node.c <= uint32(r.limit): + // node is fully included + ft.log("included fully") + r.update(node) + r.limit -= int(node.c) + case node.leaf(): + tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) + r.tailRefs = append(r.tailRefs, tail) + ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) + default: + pLeft := p.left() + left, haveLeft := ft.visitNode(node.left, pLeft, r) + if haveLeft { + if int(left.c) <= r.limit { + // left 
node is fully included + ft.log("include left in full") + r.update(left) + r.limit -= int(left.c) + } else { + // we must stop somewhere in the left subtree + ft.log("descend to the left") + ft.aggregateUpToLimit(node.left, pLeft, r) + return + } + } + ft.log("descend to the right") + ft.aggregateUpToLimit(node.right, p.right(), r) + } +} + +func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) { + ft.enter("aggregateLeft: idx %d v %016x p %s limit %d", idx, v, p, r.limit) + defer func() { + ft.leave(r.fp, r.count, r.tailRefs) + }() + node, ok := ft.visitNode(idx, p, r) + switch { + case !ok || r.limit == 0: + // for r.limit == 0, it's important that we still visit the node + // so that we can get the item immediately following the included items + ft.log("stop: ok %v r.limit %d", ok, r.limit) + case p.len() == ft.maxDepth: + if node.left != noIndex || node.right != noIndex { + panic("BUG: node @ maxDepth has children") + } + tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) + r.tailRefs = append(r.tailRefs, tail) + ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) + case node.leaf(): // TBD: combine with prev + // For leaf 1-nodes, we can use the fingerprint to get tailRef + // by which the actual IDs will be selected + if node.c != 1 { + panic("BUG: leaf non-1 node below maxDepth") + } + tail := ft.tailRefFromFingerprint(idx, node.fp, r.takeAtMost(1)) + r.tailRefs = append(r.tailRefs, tail) + ft.log("add prefix to the tails (1-leaf): %016x (fp %s) => limit %d", tail.ref, node.fp, r.limit) + case v&bit63 == 0: + ft.log("incl right node %d + go left to node %d", node.right, node.left) + if node.right != noIndex { + ft.aggregateUpToLimit(node.right, p.right(), r) + } + ft.aggregateLeft(node.left, v<<1, p.left(), r) + default: + ft.log("go right to node %d", node.right) + ft.aggregateLeft(node.right, v<<1, p.right(), r) + } +} + +func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r 
*aggResult) { + ft.enter("aggregateRight: idx %d v %016x p %s limit %d", idx, v, p, r.limit) + defer func() { + ft.leave(r.fp, r.count, r.tailRefs) + }() + node, ok := ft.visitNode(idx, p, r) + switch { + case !ok || r.limit == 0: + // for r.limit == 0, it's important that we still visit the node + // so that we can get the item immediately following the included items + ft.log("stop: ok %v r.limit %d", ok, r.limit) + case p.len() == ft.maxDepth: + if node.left != noIndex || node.right != noIndex { + panic("BUG: node @ maxDepth has children") + } + tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) + r.tailRefs = append(r.tailRefs, tail) + ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) + case node.leaf(): + // For leaf 1-nodes, we can use the fingerprint to get tailRef + // by which the actual IDs will be selected + if node.c != 1 { + panic("BUG: leaf non-1 node below maxDepth") + } + tail := ft.tailRefFromFingerprint(idx, node.fp, r.takeAtMost(1)) + r.tailRefs = append(r.tailRefs, tail) + ft.log("add prefix to the tails (1-leaf): %016x (fp %s) => limit %d", tail.ref, node.fp, r.limit) + case v&bit63 == 0: + ft.log("go left to node %d", node.left) + ft.aggregateRight(node.left, v<<1, p.left(), r) + default: + ft.log("incl left node %d + go right to node %d", node.left, node.right) + if node.left != noIndex { + ft.aggregateUpToLimit(node.left, p.left(), r) + } + ft.aggregateRight(node.right, v<<1, p.right(), r) + } +} + +func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { + ft.rootMtx.Lock() + defer ft.rootMtx.Unlock() + ft.enter("aggregateInterval: x %s y %s limit %d", hex.EncodeToString(x), hex.EncodeToString(y), limit) + defer func() { + ft.leave(r) + }() + r = aggResult{limit: limit, lastVisited: noIndex} + r.itype = bytes.Compare(x, y) + switch { + case r.itype == 0: + // the whole set + if ft.root != noIndex { + ft.log("whole set") + ft.aggregateUpToLimit(ft.root, 0, &r) + } else { + ft.log("empty 
set (no root)") + } + case r.itype < 0: + // "proper" interval: [x; lca); (lca; y) + p := commonPrefix(x, y) + lcaIdx, followedPrefix, found := ft.followPrefix(ft.root, p, 0) + var lcaNode node + if lcaIdx != noIndex { + lcaNode = ft.np.node(lcaIdx) + } + ft.log("commonPrefix %s lca %d found %v", p, lcaIdx, found) + switch { + case found && !lcaNode.leaf(): + if followedPrefix != p { + panic("BUG: bad followedPrefix") + } + ft.visitNode(lcaIdx, followedPrefix, &r) + ft.aggregateLeft(lcaNode.left, load64(x)<<(p.len()+1), p.left(), &r) + ft.aggregateRight(lcaNode.right, load64(y)<<(p.len()+1), p.right(), &r) + case lcaIdx != noIndex: + ft.log("commonPrefix %s NOT found but have lca %d", p, lcaIdx) + // Didn't reach LCA in the tree b/c ended up + // at a leaf, just use the prefix to go + // through the IDs + if lcaNode.leaf() { + ft.visitNode(lcaIdx, followedPrefix, &r) + r.tailRefs = append(r.tailRefs, + ft.tailRefFromNodeAndPrefix( + lcaIdx, lcaNode, followedPrefix, r.takeAtMost(limit))) + } + } + default: + // inverse interval: [min; y); [x; max] + pf0 := preFirst0(x) + idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, 0) + var pf0Node node + if idx0 != noIndex { + pf0Node = ft.np.node(idx0) + } + ft.log("pf0 %s idx0 %d found %v", pf0, idx0, found) + switch { + case found && !pf0Node.leaf(): + if followedPrefix != pf0 { + panic("BUG: bad followedPrefix") + } + ft.aggregateLeft(idx0, load64(x)< fp %s count %d", + tailRef, + hex.EncodeToString(id), + r.fp.String(), r.count) + wasWithinRange = true + } else { + // if we were within the range but now we're out of it, + // this means we're at or beyond y and can stop + // return !wasWithinRange + // QQQQQ: rmme + if wasWithinRange { + ft.log("tailRef %v: id %s outside range after id(s) within range => terminating", + tailRef, + hex.EncodeToString(id)) + // TBD: QQQQQ: terminate only for this tailRef + return noStop + } else { + ft.log("tailRef %v: id %s outside range => continuing", + tailRef, + 
hex.EncodeToString(id)) + return true + } + } + return true + }); err != nil { + return fpResult{}, err + } + return fpResult{fp: r.fp, count: r.count, itype: r.itype}, nil +} + +func (ft *fpTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) { + if idx == noIndex { + return + } + node := ft.np.node(idx) + + leaf := node.leaf() + countStr := strconv.Itoa(int(node.c)) + if leaf { + countStr = "LEAF-" + countStr + } + fmt.Fprintf(w, "%s%sidx=%d %s %s [%d]\n", indent, dir, idx, node.fp, countStr, ft.np.refCount(idx)) + if !leaf { + indent += " " + ft.dumpNode(w, node.left, indent, "l: ") + ft.dumpNode(w, node.right, indent, "r: ") + } +} + +func (ft *fpTree) dump(w io.Writer) { + if ft.root == noIndex { + fmt.Fprintln(w, "empty tree") + } else { + ft.dumpNode(w, ft.root, "", "") + } +} + +type memIDStore struct { + mtx sync.Mutex + ids map[uint64][]KeyBytes + maxDepth int +} + +var _ idStore = &memIDStore{} + +func newMemIDStore(maxDepth int) *memIDStore { + return &memIDStore{maxDepth: maxDepth} +} + +func (m *memIDStore) clone() idStore { + m.mtx.Lock() + defer m.mtx.Unlock() + s := newMemIDStore(m.maxDepth) + if m.ids != nil { + s.ids = make(map[uint64][]KeyBytes, len(m.ids)) + for k, v := range m.ids { + s.ids[k] = slices.Clone(v) + } + } + return s +} + +func (m *memIDStore) registerHash(h KeyBytes) error { + m.mtx.Lock() + defer m.mtx.Unlock() + if m.ids == nil { + m.ids = make(map[uint64][]KeyBytes, 1<> (64 - m.maxDepth) + s := m.ids[idx] + n := slices.IndexFunc(s, func(cur KeyBytes) bool { + return bytes.Compare(cur, h) > 0 + }) + if n < 0 { + m.ids[idx] = append(s, h) + } else { + m.ids[idx] = slices.Insert(s, n, h) + } + return nil +} + +func (m *memIDStore) iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error { + m.mtx.Lock() + defer m.mtx.Unlock() + if m.ids == nil { + return nil + } + // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: maxDepth %d tailRefs %v\n", m.maxDepth, tailRefs) + // fmt.Fprintf(os.Stderr, 
"QQQQQ: memIDStore: iterateIDs: ids %#v\n", m.ids) + for _, t := range tailRefs { + count := t.limit + if count == 0 { + // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v: count == 0\n", t) + continue + } + for _, id := range m.ids[t.ref] { + if count == 0 { + // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: count == 0\n", t, hex.EncodeToString(id)) + break + } + if count > 0 { + // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: dec count\n", t, hex.EncodeToString(id)) + count-- + } + // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: call\n", t, hex.EncodeToString(id)) + if !toCall(t, id) { + // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: stop\n", t, hex.EncodeToString(id)) + return nil + } + } + } + return nil +} + +type sqlIDStore struct { + db sql.Database + query string + keyLen int + maxDepth int +} + +var _ idStore = &sqlIDStore{} + +func newSQLIDStore(db sql.Database, query string, keyLen, maxDepth int) *sqlIDStore { + return &sqlIDStore{db: db, query: query, keyLen: keyLen, maxDepth: maxDepth} +} + +func (s *sqlIDStore) clone() idStore { + return newSQLIDStore(s.db, s.query, s.keyLen, s.maxDepth) +} + +func (s *sqlIDStore) registerHash(h KeyBytes) error { + // should be registered by the handler code + return nil +} + +func (s *sqlIDStore) iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error { + cont := true + for _, t := range tailRefs { + if t.limit == 0 { + continue + } + p := mkprefix(t.ref, s.maxDepth) + minID := make([]byte, s.keyLen) + maxID := make([]byte, s.keyLen) + p.minID(minID[:]) + p.maxID(maxID[:]) + // start := time.Now() + query := s.query + if t.limit > 0 { + query += " LIMIT " + strconv.Itoa(t.limit) + } + if _, err := s.db.Exec( + query, + func(stmt *sql.Statement) { + stmt.BindBytes(1, minID) + stmt.BindBytes(2, maxID) + }, + func(stmt *sql.Statement) bool { + id := make(KeyBytes, s.keyLen) + stmt.ColumnBytes(0, id) + 
cont = toCall(t, id) + return cont + }, + ); err != nil { + return err + } + // fmt.Fprintf(os.Stderr, "QQQQQ: %v: sel atxs between %s and %s\n", time.Now().Sub(start), minID.String(), maxID.String()) + if !cont { + break + } + } + return nil +} + +type dbBackedStore struct { + *sqlIDStore + *memIDStore + maxDepth int +} + +var _ idStore = &dbBackedStore{} + +func newDBBackedStore(db sql.Database, query string, keyLen, maxDepth int) *dbBackedStore { + return &dbBackedStore{ + sqlIDStore: newSQLIDStore(db, query, keyLen, maxDepth), + memIDStore: newMemIDStore(maxDepth), + maxDepth: maxDepth, + } +} + +func (s *dbBackedStore) clone() idStore { + return &dbBackedStore{ + sqlIDStore: s.sqlIDStore.clone().(*sqlIDStore), + memIDStore: s.memIDStore.clone().(*memIDStore), + maxDepth: s.maxDepth, + } +} + +func (s *dbBackedStore) registerHash(h KeyBytes) error { + return s.memIDStore.registerHash(h) +} + +func (s *dbBackedStore) iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error { + type memItem struct { + tailRef tailRef + id KeyBytes + } + var memItems []memItem + s.memIDStore.iterateIDs(tailRefs, func(tailRef tailRef, id KeyBytes) bool { + memItems = append(memItems, memItem{tailRef: tailRef, id: id}) + return true + }) + cont := true + limits := make(map[uint64]int, len(tailRefs)) + for _, t := range tailRefs { + if t.limit >= 0 { + limits[t.ref] += t.limit + } + } + if err := s.sqlIDStore.iterateIDs(tailRefs, func(tailRef tailRef, id KeyBytes) bool { + ref := load64(id) >> (64 - s.maxDepth) + limit, haveLimit := limits[ref] + for len(memItems) > 0 && bytes.Compare(memItems[0].id, id) < 0 { + if haveLimit && limit == 0 { + return false + } + cont = toCall(memItems[0].tailRef, memItems[0].id) + if !cont { + return false + } + limits[ref] = limit - 1 + memItems = memItems[1:] + } + if haveLimit && limit == 0 { + return false + } + cont = toCall(tailRef, id) + limits[ref] = limit - 1 + return cont + }); err != nil { + return err + } + if cont { + for 
_, mi := range memItems { + ref := load64(mi.id) >> (64 - s.maxDepth) + limit, haveLimit := limits[ref] + if haveLimit && limit == 0 { + break + } + if !toCall(mi.tailRef, mi.id) { + break + } + limits[ref] = limit - 1 + } + } + return nil +} + +func idWithinInterval(id, x, y KeyBytes, itype int) bool { + switch itype { + case 0: + return true + case -1: + return bytes.Compare(id, x) >= 0 && bytes.Compare(id, y) < 0 + default: + return bytes.Compare(id, y) < 0 || bytes.Compare(id, x) >= 0 + } +} + +// TBD: optimize, get rid of binary.BigEndian.* diff --git a/sync2/dbsync/dbsync_test.go b/sync2/dbsync/fptree_test.go similarity index 56% rename from sync2/dbsync/dbsync_test.go rename to sync2/dbsync/fptree_test.go index 45302e2ca5..c5f73cf399 100644 --- a/sync2/dbsync/dbsync_test.go +++ b/sync2/dbsync/fptree_test.go @@ -1,9 +1,9 @@ package dbsync import ( - "bytes" "fmt" "math/rand" + "reflect" "runtime" "slices" "strings" @@ -188,36 +188,42 @@ func TestCommonPrefix(t *testing.T) { } } -type fakeATXStore struct { - db sql.StateDatabase +type fakeIDDBStore struct { + db sql.Database *sqlIDStore } -func newFakeATXIDStore(db sql.StateDatabase) *fakeATXStore { - return &fakeATXStore{db: db, sqlIDStore: newSQLIDStore(db)} +var _ idStore = &fakeIDDBStore{} + +const fakeIDQuery = "select id from foo where id between ? and ? 
order by id" + +func newFakeATXIDStore(db sql.Database, maxDepth int) *fakeIDDBStore { + return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, fakeIDQuery, 32, maxDepth)} } -func (s *fakeATXStore) registerHash(h []byte, maxDepth int) error { - if err := s.sqlIDStore.registerHash(h, maxDepth); err != nil { +func (s *fakeIDDBStore) registerHash(h KeyBytes) error { + if err := s.sqlIDStore.registerHash(h); err != nil { return err } - _, err := s.db.Exec(` - insert into atxs (id, epoch, effective_num_units, received) - values (?, 1, 1, 0)`, + _, err := s.db.Exec("insert into foo (id) values (?)", func(stmt *sql.Statement) { stmt.BindBytes(1, h) }, nil) return err } -func testFPTree(t *testing.T, idStore idStore) { +type idStoreFunc func(maxDepth int) idStore + +func testFPTree(t *testing.T, makeIDStore idStoreFunc) { for _, tc := range []struct { - name string - ids []string - results map[[2]int]fpResult + name string + maxDepth int + ids []string + results map[[3]int]fpResult }{ { - name: "ids1", + name: "ids1", + maxDepth: 24, ids: []string{ "0000000000000000000000000000000000000000000000000000000000000000", "123456789abcdef0000000000000000000000000000000000000000000000000", @@ -225,68 +231,156 @@ func testFPTree(t *testing.T, idStore idStore) { "8888888888888888888888888888888888888888888888888888888888888888", "abcdef1234567890000000000000000000000000000000000000000000000000", }, - results: map[[2]int]fpResult{ - {0, 0}: { + results: map[[3]int]fpResult{ + {0, 0, -1}: { fp: hexToFingerprint("642464b773377bbddddddddd"), count: 5, + itype: 0, }, - {4, 4}: { + {0, 0, 3}: { + fp: hexToFingerprint("4761032dcfe98ba555555555"), + count: 3, + itype: 0, + }, + {4, 4, -1}: { fp: hexToFingerprint("642464b773377bbddddddddd"), count: 5, + itype: 0, }, - {0, 1}: { + {0, 1, -1}: { fp: hexToFingerprint("000000000000000000000000"), count: 1, + itype: -1, + }, + {0, 3, -1}: { + fp: hexToFingerprint("4761032dcfe98ba555555555"), + count: 3, + itype: -1, + }, + {0, 4, 3}: { + fp: 
hexToFingerprint("4761032dcfe98ba555555555"), + count: 3, + itype: -1, }, - {1, 4}: { + {1, 4, -1}: { fp: hexToFingerprint("cfe98ba54761032ddddddddd"), count: 3, + itype: -1, }, - {1, 0}: { + {1, 0, -1}: { fp: hexToFingerprint("642464b773377bbddddddddd"), count: 4, + itype: 1, }, - {2, 0}: { + {2, 0, -1}: { fp: hexToFingerprint("761032cfe98ba54ddddddddd"), count: 3, + itype: 1, }, - {3, 1}: { + {3, 1, -1}: { fp: hexToFingerprint("2345679abcdef01888888888"), count: 3, + itype: 1, }, - {3, 2}: { + {3, 2, -1}: { fp: hexToFingerprint("317131e226622ee888888888"), count: 4, + itype: 1, + }, + {3, 2, 3}: { + fp: hexToFingerprint("2345679abcdef01888888888"), + count: 3, + itype: 1, }, }, }, { - name: "ids2", + name: "ids2", + maxDepth: 24, ids: []string{ "6e476ca729c3840d0118785496e488124ee7dade1aef0c87c6edc78f72e4904f", "829977b444c8408dcddc1210536f3b3bdc7fd97777426264b9ac8f70b97a7fd1", "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", "e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", }, - results: map[[2]int]fpResult{ - {0, 0}: { + results: map[[3]int]fpResult{ + {0, 0, -1}: { fp: hexToFingerprint("a76fc452775b55e0dacd8be5"), count: 4, + itype: 0, }, - {0, 3}: { + {0, 0, 3}: { fp: hexToFingerprint("4e5ea7ab7f38576018653418"), count: 3, + itype: 0, }, - {3, 1}: { + {0, 3, -1}: { + fp: hexToFingerprint("4e5ea7ab7f38576018653418"), + count: 3, + itype: -1, + }, + {3, 1, -1}: { fp: hexToFingerprint("87760f5e21a0868dc3b0c7a9"), count: 2, + itype: 1, + }, + {3, 2, -1}: { + fp: hexToFingerprint("05ef78ea6568c6000e6cd5b9"), + count: 3, + itype: 1, + }, + }, + }, + { + name: "ids3", + maxDepth: 4, + ids: []string{ + "01dd08ec0c477312f0ef010789b4a7c65d664e3d07e9fde246c70ee2af71f4c7", + "051f49b4621dad18ab3582eeeda995bba5fdd0a23d0ae0387e312e4706c62d26", + "0743ede445d407d164e4139c440e6f09273d6ac088f929c5781ffd6c63806622", + "114991f28f34d1239d9b617ad1d0e3497fd8f7c5320c1bfc51042cddb3c4d4d1", + 
"120bf12c57659760f1b0a5cf5f85e23492f92822e714543fc4be732d4de3d284", + "20e8cb9ba6fba6926ed5e0101e57881094d831a9b26a68d73b04d30a2100075b", + "2403eb652598ee893b84d854f222fc0231ee1c3823bba9dfbe7bc8521eb10831", + "282ed276fe896730d856ca373837ef6f89b2109d04a0b17eac152df73fc21d90", + "2e6690d307c831a1e87039fcb67a0cdd44867271a8955b8003e74f4c644bd7bd", + "360ca30d3013940704a5a095318e022ee5d36618c4ad1b2d084e2bc797a1793d", + "3f52547180ba19ae700cb24b220fac01159c489e4ab127ee7ae046069165587a", + "4df3f9fb5b1cc7a7921dbdaf27afd16f1749f4134d611eead0a1e9cf34c51994", + "625df1cf9e472cd647b3e5fd065be537385889b1b913a0336787a37f12d55a02", + "6feaf52c2f8030e3eb21935f67d6ced8b37535387a086d46de8f31e5b67e1f71", + "75a5176eb4cc182302120e991f88cbe3b01e19a28dfd972a441a5bcde57f6879", + "768281853be35aa50156598308f6c5b12a4457615551c688712607069517714f", + "7686323c12f0853555450ce1ec22700861530fa67d523587bf7078f915204cc5", + "a6df4f61a0e351bc539b32b4262446ac27766073515ef4b5203941fef7343ebc", + "a740ea1cdb1c144da5bc4f96833a4c611fa7196d4ebaa89a1bd209abe519503a", + "ab0960667a9bf57138c1a3f7d54b242e23b6c36fd8f2a645ed9217050dd5e011", + "af5adcf404035e9ee88377230d26406702259ad25a04d425bd3c2cff546d32c0", + "afd06a52970126024887099ed40d2400b9bb9505f171fb203baf74f7199f7c7e", + "b520c3bb04061813e57d75db0a06f711b635b0aef1561d01859f122439437d61", + "b525b9ecbf8a888a3b01669c7c7d5656b6b6a7c4df3bbe5402fbe4e718bad4bb", + "b84d4bf077d68821ee9203aaf6eee90fe892f42faee939c974f719c29117ddb6", + "bf0f6ef1cee0eb3131fb24ef52e6ac8f0a22d85d32c3fe3255d921037423df1b", + "c72caa7c9822d6c77a254c12bc17eae8e5d637a929c94cc84aa4662d4baa508d", + "d4375ae1c64c3d2167bb467acc63083851d834fa24f285d4a1220c407287cd56", + "d552081889142b74ab0f0cb9da0de192cdd549213a2d348e0cc21061c196ed6a", + "e1729d5eda4d6dac38070551a0956f3bcf0d8ac34b45a0b7e5553315cc662ebe", + "e41d8c3a7607ec5423cc376a34d21494f2d0c625fb9bebcec09d06c188ab7f3f", + "e9110a384198b47be2bb63e64f094069a0ee9a013e013176bbe8189834c5e4c8", + }, + results: map[[3]int]fpResult{ + 
{31, 0, -1}: { + fp: hexToFingerprint("e9110a384198b47be2bb63e6"), + count: 1, + itype: 1, }, }, }, } { t.Run(tc.name, func(t *testing.T) { var np nodePool - ft := newFPTree(&np, idStore, 24) + idStore := makeIDStore(tc.maxDepth) + ft := newFPTree(&np, idStore, tc.maxDepth) var hs []types.Hash32 for _, hex := range tc.ids { t.Logf("add: %s", hex) @@ -299,30 +393,111 @@ func testFPTree(t *testing.T, idStore idStore) { ft.dump(&sb) t.Logf("tree:\n%s", sb.String()) - checkTree(t, ft, 24) + checkTree(t, ft, tc.maxDepth) for idRange, expResult := range tc.results { x := hs[idRange[0]] y := hs[idRange[1]] - fpr, err := ft.fingerprintInterval(x[:], y[:]) + fpr, err := ft.fingerprintInterval(x[:], y[:], idRange[2]) require.NoError(t, err) require.Equal(t, expResult, fpr) } + + ft.release() + require.Zero(t, np.count()) }) } } func TestFPTree(t *testing.T) { - t.Run("in-memory id store", func(t *testing.T) { - testFPTree(t, &memIDStore{}) - }) + // t.Run("in-memory id store", func(t *testing.T) { + // testFPTree(t, func(maxDepth int) idStore { return newMemIDStore(maxDepth) }) + // }) t.Run("fake ATX store", func(t *testing.T) { - db := statesql.InMemory() - defer db.Close() - testFPTree(t, newFakeATXIDStore(db)) + db := populateDB(t, 32, nil) + testFPTree(t, func(maxDepth int) idStore { + _, err := db.Exec("delete from foo", nil, nil) + require.NoError(t, err) + return newFakeATXIDStore(db, maxDepth) + }) }) } +func TestFPTreeClone(t *testing.T) { + var np nodePool + ft1 := newFPTree(&np, newMemIDStore(24), 24) + hashes := []types.Hash32{ + types.HexToHash32("1111111111111111111111111111111111111111111111111111111111111111"), + types.HexToHash32("3333333333333333333333333333333333333333333333333333333333333333"), + types.HexToHash32("4444444444444444444444444444444444444444444444444444444444444444"), + } + ft1.addHash(hashes[0][:]) + ft1.addHash(hashes[1][:]) + + fpr, err := ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + require.NoError(t, err) + 
require.Equal(t, fpResult{ + fp: hexToFingerprint("222222222222222222222222"), + count: 2, + itype: 0, + }, fpr) + + var sb strings.Builder + ft1.dump(&sb) + t.Logf("ft1 pre-clone:\n%s", sb.String()) + + ft2 := ft1.clone() + + sb.Reset() + ft1.dump(&sb) + t.Logf("ft1 after-clone:\n%s", sb.String()) + + sb.Reset() + ft2.dump(&sb) + t.Logf("ft2 after-clone:\n%s", sb.String()) + + // original tree unchanged --- rmme!!!! + fpr, err = ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + require.NoError(t, err) + require.Equal(t, fpResult{ + fp: hexToFingerprint("222222222222222222222222"), + count: 2, + itype: 0, + }, fpr) + + ft2.addHash(hashes[2][:]) + + fpr, err = ft2.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + require.NoError(t, err) + require.Equal(t, fpResult{ + fp: hexToFingerprint("666666666666666666666666"), + count: 3, + itype: 0, + }, fpr) + + // original tree unchanged + fpr, err = ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + require.NoError(t, err) + require.Equal(t, fpResult{ + fp: hexToFingerprint("222222222222222222222222"), + count: 2, + itype: 0, + }, fpr) + + sb.Reset() + ft1.dump(&sb) + t.Logf("ft1:\n%s", sb.String()) + + sb.Reset() + ft2.dump(&sb) + t.Logf("ft2:\n%s", sb.String()) + + ft1.release() + ft2.release() + + require.Zero(t, np.count()) +} + type hashList []types.Hash32 func (l hashList) findGTE(h types.Hash32) int { @@ -364,13 +539,53 @@ func checkTree(t *testing.T, ft *fpTree, maxDepth int) { checkNode(t, ft, ft.root, 0) } -func testInMemFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool) { +func repeatTestFPTreeManyItems( + t *testing.T, + makeIDStore idStoreFunc, + randomXY bool, + numItems, maxDepth int, +) { + for i := 0; i < 100; i++ { + testFPTreeManyItems(t, makeIDStore(maxDepth), randomXY, numItems, maxDepth) + } +} + +func dumbFP(hs hashList, x, y types.Hash32) fpResult { + var fpr fpResult + fpr.itype = x.Compare(y) + switch fpr.itype { + case -1: + pX := hs.findGTE(x) + pY := hs.findGTE(y) 
+ // t.Logf("x=%s y=%s pX=%d y=%d", x.String(), y.String(), pX, pY) + for p := pX; p < pY; p++ { + // t.Logf("XOR %s", hs[p].String()) + fpr.fp.update(hs[p][:]) + } + fpr.count = uint32(pY - pX) + case 1: + pX := hs.findGTE(x) + pY := hs.findGTE(y) + for p := 0; p < pY; p++ { + fpr.fp.update(hs[p][:]) + } + for p := pX; p < len(hs); p++ { + fpr.fp.update(hs[p][:]) + } + fpr.count = uint32(pY + len(hs) - pX) + default: + for _, h := range hs { + fpr.fp.update(h[:]) + } + fpr.count = uint32(len(hs)) + } + return fpr +} + +func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, maxDepth int) { var np nodePool - const ( - numItems = 1 << 16 - maxDepth = 24 - ) ft := newFPTree(&np, idStore, maxDepth) + // ft.traceEnabled = true hs := make(hashList, numItems) var fp fingerprint for i := range hs { @@ -385,9 +600,9 @@ func testInMemFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool) { checkTree(t, ft, maxDepth) - fpr, err := ft.fingerprintInterval(hs[0][:], hs[0][:]) + fpr, err := ft.fingerprintInterval(hs[0][:], hs[0][:], -1) require.NoError(t, err) - require.Equal(t, fpResult{fp: fp, count: numItems}, fpr) + require.Equal(t, fpResult{fp: fp, count: uint32(numItems), itype: 0}, fpr) for i := 0; i < 100; i++ { // TBD: allow reverse order var x, y types.Hash32 @@ -398,102 +613,93 @@ func testInMemFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool) { x = hs[rand.Intn(numItems)] y = hs[rand.Intn(numItems)] } - var ( - expFP fingerprint - expN uint32 - ) - switch bytes.Compare(x[:], y[:]) { - case -1: - pX := hs.findGTE(x) - pY := hs.findGTE(y) - // t.Logf("x=%s y=%s pX=%d y=%d", x.String(), y.String(), pX, pY) - for p := pX; p < pY; p++ { - // t.Logf("XOR %s", hs[p].String()) - expFP.update(hs[p][:]) - } - expN = uint32(pY - pX) - case 1: - pX := hs.findGTE(x) - pY := hs.findGTE(y) - for p := 0; p < pY; p++ { - expFP.update(hs[p][:]) - } - for p := pX; p < len(hs); p++ { - expFP.update(hs[p][:]) + expFPR := dumbFP(hs, x, y) + fpr, 
err := ft.fingerprintInterval(x[:], y[:], -1) + require.NoError(t, err) + + // QQQQQ: rm + if !reflect.DeepEqual(fpr, expFPR) { + t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) + for _, h := range hs { + t.Logf("QQQQQ: hash: %s", h.String()) } - expN = uint32(pY + len(hs) - pX) - default: - expFP = fp - expN = numItems + var sb strings.Builder + ft.dump(&sb) + t.Logf("QQQQQ: tree:\n%s", sb.String()) } - fpr, err := ft.fingerprintInterval(x[:], y[:]) - require.NoError(t, err) - require.Equal(t, fpResult{ - fp: expFP, - count: expN, - }, fpr) + // QQQQQ: /rm + + require.Equal(t, expFPR, fpr) } } -func TestInMemFPTreeManyItems(t *testing.T) { +func TestFPTreeManyItems(t *testing.T) { + const ( + // numItems = 1 << 16 + // maxDepth = 24 + numItems = 1 << 5 + maxDepth = 4 + ) t.Run("bounds from the set", func(t *testing.T) { - var idStore memIDStore - testInMemFPTreeManyItems(t, &idStore, false) - total := 0 - nums := make(map[int]int) - for _, ids := range idStore.ids { - nums[len(ids)]++ - total += len(ids) - } - t.Logf("total %d, nums %#v", total, nums) + repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { + return newMemIDStore(maxDepth) + }, false, numItems, maxDepth) }) t.Run("random bounds", func(t *testing.T) { - testInMemFPTreeManyItems(t, &memIDStore{}, true) + repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { + return newMemIDStore(maxDepth) + }, true, numItems, maxDepth) }) t.Run("SQL, bounds from the set", func(t *testing.T) { - db := statesql.InMemory() - defer db.Close() - testInMemFPTreeManyItems(t, newFakeATXIDStore(db), false) - + db := populateDB(t, 32, nil) + repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { + _, err := db.Exec("delete from foo", nil, nil) + require.NoError(t, err) + return newFakeATXIDStore(db, maxDepth) + }, false, numItems, maxDepth) }) t.Run("SQL, random bounds", func(t *testing.T) { - db := statesql.InMemory() - defer db.Close() - testInMemFPTreeManyItems(t, newFakeATXIDStore(db), true) + db := 
populateDB(t, 32, nil) + repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { + _, err := db.Exec("delete from foo", nil, nil) + require.NoError(t, err) + return newFakeATXIDStore(db, maxDepth) + }, true, numItems, maxDepth) }) } const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" -func dumbAggATXs(t *testing.T, db sql.StateDatabase, x, y types.Hash32) fpResult { - var fp fingerprint - ts := time.Now() - nRows, err := db.Exec( - // BETWEEN is faster than >= and < - "select id from atxs where id between ? and ?", - func(stmt *sql.Statement) { - stmt.BindBytes(1, x[:]) - stmt.BindBytes(2, y[:]) - }, - func(stmt *sql.Statement) bool { - var id types.Hash32 - stmt.ColumnBytes(0, id[:]) - if id != y { - fp.update(id[:]) - } - return true - }, - ) - require.NoError(t, err) - t.Logf("QQQQQ: %v: dumb fp between %s and %s", time.Now().Sub(ts), x.String(), y.String()) - return fpResult{ - fp: fp, - count: uint32(nRows), - } -} - -func testFP(t *testing.T, maxDepth int) { +// func dumbAggATXs(t *testing.T, db sql.StateDatabase, x, y types.Hash32) fpResult { +// var fp fingerprint +// ts := time.Now() +// nRows, err := db.Exec( +// // BETWEEN is faster than >= and < +// "select id from atxs where id between ? and ? 
order by id", +// func(stmt *sql.Statement) { +// stmt.BindBytes(1, x[:]) +// stmt.BindBytes(2, y[:]) +// }, +// func(stmt *sql.Statement) bool { +// var id types.Hash32 +// stmt.ColumnBytes(0, id[:]) +// if id != y { +// fp.update(id[:]) +// } +// return true +// }, +// ) +// require.NoError(t, err) +// t.Logf("QQQQQ: %v: dumb fp between %s and %s", time.Now().Sub(ts), x.String(), y.String()) +// return fpResult{ +// fp: fp, +// count: uint32(nRows), +// itype: x.Compare(y), +// } +// } + +func testATXFP(t *testing.T, maxDepth int) { runtime.GC() var stats1 runtime.MemStats runtime.ReadMemStats(&stats1) @@ -508,14 +714,16 @@ func testFP(t *testing.T, maxDepth int) { // var prev uint64 // first := true // where epoch=23 - store := newSQLIDStore(db) + store := newSQLIDStore(db, "select id from atxs where id between ? and ? order by id", 32, maxDepth) var np nodePool ft := newFPTree(&np, store, maxDepth) t.Logf("loading IDs") + var hs []types.Hash32 _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { var id types.Hash32 stmt.ColumnBytes(0, id[:]) ft.addHash(id[:]) + hs = append(hs, id) // v := load64(id[:]) // counts[v>>40]++ // if first { @@ -548,7 +756,7 @@ func testFP(t *testing.T, maxDepth int) { for n := 0; n < numIter; n++ { x := types.RandomHash() y := types.RandomHash() - ft.fingerprintInterval(x[:], y[:]) + ft.fingerprintInterval(x[:], y[:], -1) } elapsed := time.Now().Sub(ts) @@ -563,34 +771,67 @@ func testFP(t *testing.T, maxDepth int) { float64(numIter)/elapsed.Seconds(), stats2.HeapInuse-stats1.HeapInuse) - // TBD: restore !!!! 
- // t.Logf("testing ranges") - // for n := 0; n < 10; n++ { - // x := types.RandomHash() - // y := types.RandomHash() - // // TBD: QQQQQ: dumb rev / full intervals - // if x == y { - // continue - // } - // if x.Compare(y) > 0 { - // x, y = y, x - // } - // expFPResult := dumbAggATXs(t, db, x, y) - // fpr, err := ft.fingerprintInterval(x[:], y[:]) - // require.NoError(t, err) - // require.Equal(t, expFPResult, fpr) - // } + // TODO: test incomplete ranges (with limit) + t.Logf("testing ranges") + for n := 0; n < 50; n++ { + x := types.RandomHash() + y := types.RandomHash() + t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) + expFPResult := dumbFP(hs, x, y) + //expFPResult := dumbAggATXs(t, db, x, y) + fpr, err := ft.fingerprintInterval(x[:], y[:], -1) + require.NoError(t, err) + require.Equal(t, expFPResult, fpr, "x=%s y=%s", x.String(), y.String()) + } } -func TestFP(t *testing.T) { - t.Skip("slow test") +func TestATXFP(t *testing.T) { + // t.Skip("slow test") for maxDepth := 15; maxDepth <= 23; maxDepth++ { for i := 0; i < 3; i++ { - testFP(t, maxDepth) + testATXFP(t, maxDepth) } } } +func TestDBBackedStore(t *testing.T) { + // create an in-memory-database, put some ids into it, + // create dbBackedStore, read the ids from the database and check them, + // then add some ids to the dbBackedStore but not to the database, + // and re-check the dbBackedStore contents using iterateIDs method + // use plain sql.InMemory and foo table like in TestDBRangeIterator + initialIDs := []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + } + db := populateDB(t, 8, initialIDs) + store := newDBBackedStore(db, fakeIDQuery, 8, 24) + var actualIDs []KeyBytes + require.NoError(t, store.iterateIDs([]tailRef{{ref: 0, limit: -1}}, func(_ tailRef, id KeyBytes) bool { + actualIDs = append(actualIDs, id) + return true + })) + require.Equal(t, initialIDs, actualIDs) + + require.NoError(t, 
store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) + actualIDs = nil + require.NoError(t, store.iterateIDs([]tailRef{{ref: 0, limit: -1}}, func(_ tailRef, id KeyBytes) bool { + actualIDs = append(actualIDs, id) + return true + })) + require.Equal(t, []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 2, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + {0, 0, 0, 9, 0, 0, 0, 0}, + }, actualIDs) +} + // benchmarks // maxDepth 18: 94.739µs per range, 10555.290991 ranges/s, heap diff 16621568 @@ -659,4 +900,8 @@ func TestFP(t *testing.T) { // maxDepth 23: 24.997µs per range, 40003.930386 ranges/s, heap diff 470040576 // maxDepth 23: 24.741µs per range, 40418.462446 ranges/s, heap diff 470040576 -// TBD: ensure short prefix problem is not a bug!!! +// TODO: maxDepth should be used when creating a store and not passed +// to registerHash and iterateIDs +// TODO: ensure short prefix problem is not a bug!!! 
+// TODO: QQQQQ: retrieve the end of the interval w/count in fpTree.fingerprintInterval() +// TODO: QQQQQ: test limits in TestInMemFPTreeManyItems (sep test cases SQL / non-SQL) diff --git a/sync2/dbsync/refcountpool.go b/sync2/dbsync/refcountpool.go index 211831af83..8a112a521e 100644 --- a/sync2/dbsync/refcountpool.go +++ b/sync2/dbsync/refcountpool.go @@ -1,6 +1,9 @@ package dbsync -import "sync" +import ( + "sync" + "sync/atomic" +) const freeBit = 1 << 31 const freeListMask = freeBit - 1 @@ -14,7 +17,12 @@ type rcPool[T any, I ~uint32] struct { mtx sync.Mutex entries []poolEntry[T, I] // freeList is 1-based so that rcPool doesn't need a constructor - freeList uint32 + freeList uint32 + allocCount atomic.Int64 +} + +func (rc *rcPool[T, I]) count() int { + return int(rc.allocCount.Load()) } func (rc *rcPool[T, I]) item(idx I) T { @@ -38,27 +46,41 @@ func (rc *rcPool[T, I]) add(item T) I { if rc.freeList != 0 { idx = I(rc.freeList - 1) rc.freeList = rc.entries[idx].refCount & freeListMask + if rc.freeList > uint32(len(rc.entries)) { + panic("BUG: bad freeList linkage") + } rc.entries[idx].refCount = 1 } else { idx = I(len(rc.entries)) rc.entries = append(rc.entries, poolEntry[T, I]{refCount: 1}) } rc.entries[idx].content = item + rc.allocCount.Add(1) return idx } -func (rc *rcPool[T, I]) release(idx I) { +func (rc *rcPool[T, I]) release(idx I) bool { rc.mtx.Lock() defer rc.mtx.Unlock() entry := &rc.entries[idx] + if entry.refCount&freeBit != 0 { + panic("BUG: excess release of rcPool[T, I] entry") + } if entry.refCount <= 0 { panic("BUG: negative rcPool[T, I] entry refcount") } entry.refCount-- if entry.refCount == 0 { + if rc.freeList > uint32(len(rc.entries)) { + panic("BUG: bad freeList") + } entry.refCount = rc.freeList | freeBit rc.freeList = uint32(idx + 1) + rc.allocCount.Add(-1) + return true } + + return false } func (rc *rcPool[T, I]) ref(idx I) { @@ -67,4 +89,10 @@ func (rc *rcPool[T, I]) ref(idx I) { rc.mtx.Unlock() } +func (rc *rcPool[T, I]) 
refCount(idx I) uint32 { + rc.mtx.Lock() + defer rc.mtx.Unlock() + return rc.entries[idx].refCount +} + // TODO: convert TestNodePool to TestRCPool diff --git a/sync2/dbsync/refcountpool_test.go b/sync2/dbsync/refcountpool_test.go index de637c61d7..2bac0c63ca 100644 --- a/sync2/dbsync/refcountpool_test.go +++ b/sync2/dbsync/refcountpool_test.go @@ -13,36 +13,45 @@ func TestRCPool(t *testing.T) { var pool rcPool[foo, fooIndex] idx1 := pool.add(foo{x: 1}) foo1 := pool.item(idx1) + require.Equal(t, 1, pool.count()) idx2 := pool.add(foo{x: 2}) foo2 := pool.item(idx2) + require.Equal(t, 2, pool.count()) require.Equal(t, foo{x: 1}, foo1) require.Equal(t, foo{x: 2}, foo2) idx3 := pool.add(foo{x: 3}) idx4 := pool.add(foo{x: 4}) require.Equal(t, fooIndex(3), idx4) pool.ref(idx4) + require.Equal(t, 4, pool.count()) - pool.release(idx4) + require.False(t, pool.release(idx4)) // not yet released due to an extra ref require.Equal(t, fooIndex(4), pool.add(foo{x: 5})) + require.Equal(t, 5, pool.count()) - pool.release(idx4) + require.True(t, pool.release(idx4)) // idx4 was freed require.Equal(t, idx4, pool.add(foo{x: 6})) + require.Equal(t, 5, pool.count()) // free item used just once require.Equal(t, fooIndex(5), pool.add(foo{x: 7})) + require.Equal(t, 6, pool.count()) // form a free list containing several items - pool.release(idx3) - pool.release(idx2) - pool.release(idx1) + require.True(t, pool.release(idx3)) + require.True(t, pool.release(idx2)) + require.True(t, pool.release(idx1)) + require.Equal(t, 3, pool.count()) // the free list is LIFO require.Equal(t, idx1, pool.add(foo{x: 8})) require.Equal(t, idx2, pool.add(foo{x: 9})) require.Equal(t, idx3, pool.add(foo{x: 10})) + require.Equal(t, 6, pool.count()) // the free list is exhausted require.Equal(t, fooIndex(6), pool.add(foo{x: 11})) + require.Equal(t, 7, pool.count()) } diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index d60be49788..1764626b42 100644 --- a/sync2/hashsync/handler_test.go 
+++ b/sync2/hashsync/handler_test.go @@ -416,12 +416,21 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) Requester { // maxNumSpecificA: 20000, // minNumSpecificB: 15, // maxNumSpecificB: 20, + + // QQQQQ: restore! + // maxSendRange: 1, + // numTestHashes: 100000, + // minNumSpecificA: 4, + // maxNumSpecificA: 100, + // minNumSpecificB: 4, + // maxNumSpecificB: 100, + maxSendRange: 1, - numTestHashes: 100000, - minNumSpecificA: 4, - maxNumSpecificA: 100, - minNumSpecificB: 4, - maxNumSpecificB: 100, + numTestHashes: 100, + minNumSpecificA: 2, + maxNumSpecificA: 4, + minNumSpecificB: 2, + maxNumSpecificB: 4, } var client Requester verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 63fdea7d4f..88aa1456ec 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -13,10 +13,10 @@ import ( // Iterator points to in item in ItemStore type Iterator interface { - // Equal returns true if this iterator is equal to another Iterator - Equal(other Iterator) bool // Key returns the key corresponding to iterator position. It returns // nil if the ItemStore is empty + // If the iterator is returned along with a count, the return value of Key() + // after calling Next() count times is dependent on the implementation. Key() Ordered // Next advances the iterator Next() error @@ -41,9 +41,6 @@ type ItemStore interface { // Min returns the iterator pointing at the minimum element // in the store. If the store is empty, it returns nil Min() (Iterator, error) - // Max returns the iterator pointing at the maximum element - // in the store. 
If the store is empty, it returns nil - Max() (Iterator, error) // Copy makes a shallow copy of the ItemStore Copy() ItemStore // Has returns true if the specified key is present in ItemStore diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go index 861c3d1ad8..adb27568f4 100644 --- a/sync2/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -43,44 +43,6 @@ func (m *MockIterator) EXPECT() *MockIteratorMockRecorder { return m.recorder } -// Equal mocks base method. -func (m *MockIterator) Equal(other Iterator) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Equal", other) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Equal indicates an expected call of Equal. -func (mr *MockIteratorMockRecorder) Equal(other any) *MockIteratorEqualCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Equal", reflect.TypeOf((*MockIterator)(nil).Equal), other) - return &MockIteratorEqualCall{Call: call} -} - -// MockIteratorEqualCall wrap *gomock.Call -type MockIteratorEqualCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockIteratorEqualCall) Return(arg0 bool) *MockIteratorEqualCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockIteratorEqualCall) Do(f func(Iterator) bool) *MockIteratorEqualCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockIteratorEqualCall) DoAndReturn(f func(Iterator) bool) *MockIteratorEqualCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // Key mocks base method. func (m *MockIterator) Key() Ordered { m.ctrl.T.Helper() @@ -334,45 +296,6 @@ func (c *MockItemStoreHasCall) DoAndReturn(f func(Ordered) (bool, error)) *MockI return c } -// Max mocks base method. 
-func (m *MockItemStore) Max() (Iterator, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Max") - ret0, _ := ret[0].(Iterator) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Max indicates an expected call of Max. -func (mr *MockItemStoreMockRecorder) Max() *MockItemStoreMaxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Max", reflect.TypeOf((*MockItemStore)(nil).Max)) - return &MockItemStoreMaxCall{Call: call} -} - -// MockItemStoreMaxCall wrap *gomock.Call -type MockItemStoreMaxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreMaxCall) Return(arg0 Iterator, arg1 error) *MockItemStoreMaxCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreMaxCall) Do(f func() (Iterator, error)) *MockItemStoreMaxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreMaxCall) DoAndReturn(f func() (Iterator, error)) *MockItemStoreMaxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // Min mocks base method. 
func (m *MockItemStore) Min() (Iterator, error) { m.ctrl.T.Helper() diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index a5908e0840..7dfce5494c 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "reflect" "slices" "strings" @@ -590,6 +591,36 @@ func fingerprintEqual(a, b any) bool { return reflect.DeepEqual(a, b) } +// CollectStoreItems returns the list of items in the given store +func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { + var r []K + it, err := is.Min() + if err != nil { + return nil, err + } + if it == nil || it.Key() == nil { + return nil, nil + } + info, err := is.GetRangeInfo(nil, it.Key(), it.Key(), -1) + if err != nil { + return nil, err + } + it, err = is.Min() + if err != nil { + return nil, err + } + for n := 0; n < info.Count; n++ { + k := it.Key() + if k == nil { + fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) + panic("BUG: iterator exausted before Count reached") + } + r = append(r, k.(K)) + it.Next() + } + return r, nil +} + // TBD: test: add items to the store even in case of NextMessage() failure // TBD: !!! use wire types instead of multiple Send* methods in the Conduit interface !!! // TBD: !!! queue outbound messages right in RangeSetReconciler while processing msgs, and no need for done in handleMessage this way ++ no need for complicated logic on the conduit part !!! 
diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index b579391974..e396c66c34 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -5,6 +5,7 @@ import ( "fmt" "math/rand" "slices" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -270,7 +271,10 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) ( } else if x == nil || y == nil { panic("BUG: bad X or Y") } - all := storeItemStr(ds) + all := "" + for _, k := range ds.keys { + all += string(k) + } vx := x.(sampleID) vy := y.(sampleID) if preceding != nil && preceding.Key().Compare(x) > 0 { @@ -301,16 +305,6 @@ func (ds *dumbStore) Min() (Iterator, error) { }, nil } -func (ds *dumbStore) Max() (Iterator, error) { - if len(ds.keys) == 0 { - return nil, nil - } - return &dumbStoreIterator{ - ds: ds, - n: len(ds.keys) - 1, - }, nil -} - func (ds *dumbStore) Copy() ItemStore { return &dumbStore{keys: slices.Clone(ds.keys)} } @@ -332,17 +326,6 @@ type verifiedStoreIterator struct { var _ Iterator = &verifiedStoreIterator{} -func (it verifiedStoreIterator) Equal(other Iterator) bool { - o := other.(verifiedStoreIterator) - eq1 := it.knownGood.Equal(o.knownGood) - eq2 := it.it.Equal(o.it) - assert.Equal(it.t, eq1, eq2, "iterators equal -- keys <%v> <%v> / <%v> <%v>", - it.knownGood.Key(), it.it.Key(), - o.knownGood.Key(), o.it.Key()) - assert.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys of equal iterators") - return eq2 -} - func (it verifiedStoreIterator) Key() Ordered { k1 := it.knownGood.Key() k2 := it.it.Key() @@ -471,25 +454,6 @@ func (vs *verifiedStore) Min() (Iterator, error) { }, nil } -func (vs *verifiedStore) Max() (Iterator, error) { - m1, err := vs.knownGood.Max() - require.NoError(vs.t, err) - m2, err := vs.store.Max() - require.NoError(vs.t, err) - if m1 == nil { - require.Nil(vs.t, m2, "Max") - return nil, nil - } else { - require.NotNil(vs.t, m2, "Max") - require.Equal(vs.t, m1.Key(), m2.Key(), 
"Max key") - } - return verifiedStoreIterator{ - t: vs.t, - knownGood: m1, - it: m2, - }, nil -} - func (vs *verifiedStore) Copy() ItemStore { return &verifiedStore{ t: vs.t, @@ -535,27 +499,15 @@ func makeStore(t *testing.T, f storeFactory, items string) ItemStore { } func storeItemStr(is ItemStore) string { - it, err := is.Min() - if err != nil { - panic("store min error") - } - if it == nil { - return "" - } - endAt, err := is.Min() + ids, err := CollectStoreItems[sampleID](is) if err != nil { - panic("store min error") + panic("store error") } - r := "" - for { - r += string(it.Key().(sampleID)) - if err := it.Next(); err != nil { - panic("iterator error") - } - if it.Equal(endAt) { - return r - } + var r strings.Builder + for _, id := range ids { + r.WriteString(string(id)) } + return r.String() } var testStores = []struct { diff --git a/sync2/hashsync/sync_tree.go b/sync2/hashsync/sync_tree.go index 690be069d2..6f09a838cb 100644 --- a/sync2/hashsync/sync_tree.go +++ b/sync2/hashsync/sync_tree.go @@ -45,7 +45,6 @@ type SyncTree interface { Set(k Ordered, v any) Lookup(k Ordered) (any, bool) Min() SyncTreePointer - Max() SyncTreePointer RangeFingerprint(ptr SyncTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode SyncTreePointer) Dump() string } @@ -398,20 +397,6 @@ func (st *syncTree) Min() SyncTreePointer { return st.cachedMinPtr.clone() } -func (st *syncTree) Max() SyncTreePointer { - if st.root == nil { - return nil - } - if st.cachedMaxPtr == nil { - st.cachedMaxPtr = st.rootPtr() - st.cachedMaxPtr.max() - } - if st.cachedMaxPtr.node == nil { - panic("BUG: no maxNode in a non-empty tree") - } - return st.cachedMaxPtr.clone() -} - func (st *syncTree) Fingerprint() any { if st.root == nil { return st.m.Identity() diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index 65ce3506b7..43df94cade 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -106,11 
+106,6 @@ func (sts *SyncTreeStore) Min() (Iterator, error) { return sts.iter(sts.st.Min()), nil } -// Max implements ItemStore. -func (sts *SyncTreeStore) Max() (Iterator, error) { - return sts.iter(sts.st.Max()), nil -} - // Copy implements ItemStore. func (sts *SyncTreeStore) Copy() ItemStore { return &SyncTreeStore{ diff --git a/sync2/hashsync/xorsync_test.go b/sync2/hashsync/xorsync_test.go index a017b24155..a57837a0fc 100644 --- a/sync2/hashsync/xorsync_test.go +++ b/sync2/hashsync/xorsync_test.go @@ -30,29 +30,6 @@ func TestHash32To12Xor(t *testing.T) { require.Equal(t, m.Op(m.Op(fp1, fp2), fp3), m.Op(fp1, m.Op(fp2, fp3))) } -func collectStoreItems[K Ordered](is ItemStore) (r []K) { - it, err := is.Min() - if err != nil { - panic("store min error") - } - if it == nil { - return nil - } - endAt, err := is.Min() - if err != nil { - panic("store min error") - } - for { - r = append(r, it.Key().(K)) - if err := it.Next(); err != nil { - panic("iterator error") - } - if it.Equal(endAt) { - return r - } - } -} - type catchTransferTwice struct { ItemStore t *testing.T @@ -113,8 +90,10 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB }) if sync(storeA, storeB, numSpecificA+numSpecificB, opts) { - itemsA := collectStoreItems[types.Hash32](storeA) - itemsB := collectStoreItems[types.Hash32](storeB) + itemsA, err := CollectStoreItems[types.Hash32](storeA) + require.NoError(t, err) + itemsB, err := CollectStoreItems[types.Hash32](storeB) + require.NoError(t, err) require.Equal(t, itemsA, itemsB) srcKeys := make([]types.Hash32, len(src)) for n, h := range src { diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index a72e5d7e14..d6d95813be 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -86,20 +86,8 @@ func TestP2P(t *testing.T) { for _, hsync := range hs { hsync.Stop() - min, err := hsync.ItemStore().Min() + actualItems, err := hashsync.CollectStoreItems[types.Hash32](hsync.ItemStore()) require.NoError(t, err) - it, err := 
hsync.ItemStore().Min() - require.NoError(t, err) - require.NotNil(t, it) - var actualItems []types.Hash32 - for { - k := it.Key().(types.Hash32) - actualItems = append(actualItems, k) - require.NoError(t, it.Next()) - if it.Equal(min) { - break - } - } require.ElementsMatch(t, initialSet, actualItems) } } From fa24d2156b2009e70009760332fface180652e66 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 25 Jun 2024 13:13:09 +0400 Subject: [PATCH 40/76] fix test --- sync2/dbsync/fptree_test.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index c5f73cf399..7c72307a0f 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -700,9 +700,6 @@ const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7 // } func testATXFP(t *testing.T, maxDepth int) { - runtime.GC() - var stats1 runtime.MemStats - runtime.ReadMemStats(&stats1) // t.Skip("slow tmp test") // counts := make(map[uint64]uint64) // prefLens := make(map[int]int) @@ -714,15 +711,12 @@ func testATXFP(t *testing.T, maxDepth int) { // var prev uint64 // first := true // where epoch=23 - store := newSQLIDStore(db, "select id from atxs where id between ? and ? order by id", 32, maxDepth) var np nodePool - ft := newFPTree(&np, store, maxDepth) t.Logf("loading IDs") var hs []types.Hash32 _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { var id types.Hash32 stmt.ColumnBytes(0, id[:]) - ft.addHash(id[:]) hs = append(hs, id) // v := load64(id[:]) // counts[v>>40]++ @@ -735,6 +729,16 @@ func testATXFP(t *testing.T, maxDepth int) { return true }) require.NoError(t, err) + + runtime.GC() + var stats1 runtime.MemStats + runtime.ReadMemStats(&stats1) + store := newSQLIDStore(db, "select id from atxs where id between ? and ? 
order by id", 32, maxDepth) + ft := newFPTree(&np, store, maxDepth) + for _, id := range hs { + ft.addHash(id[:]) + } + // countFreq := make(map[uint64]int) // for _, c := range counts { // countFreq[c]++ @@ -900,8 +904,6 @@ func TestDBBackedStore(t *testing.T) { // maxDepth 23: 24.997µs per range, 40003.930386 ranges/s, heap diff 470040576 // maxDepth 23: 24.741µs per range, 40418.462446 ranges/s, heap diff 470040576 -// TODO: maxDepth should be used when creating a store and not passed -// to registerHash and iterateIDs -// TODO: ensure short prefix problem is not a bug!!! // TODO: QQQQQ: retrieve the end of the interval w/count in fpTree.fingerprintInterval() // TODO: QQQQQ: test limits in TestInMemFPTreeManyItems (sep test cases SQL / non-SQL) +// TODO: the returned RangeInfo.End iterators should be cyclic From 8a54baee8d8b968e97c2d57e249fa675bc4bd170 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sun, 14 Jul 2024 01:54:25 +0400 Subject: [PATCH 41/76] re-enable in-mem store test, skip slow TestATXFP --- sync2/dbsync/fptree_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 7c72307a0f..d864e6f980 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -410,9 +410,9 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { } func TestFPTree(t *testing.T) { - // t.Run("in-memory id store", func(t *testing.T) { - // testFPTree(t, func(maxDepth int) idStore { return newMemIDStore(maxDepth) }) - // }) + t.Run("in-memory id store", func(t *testing.T) { + testFPTree(t, func(maxDepth int) idStore { return newMemIDStore(maxDepth) }) + }) t.Run("fake ATX store", func(t *testing.T) { db := populateDB(t, 32, nil) testFPTree(t, func(maxDepth int) idStore { @@ -790,7 +790,7 @@ func testATXFP(t *testing.T, maxDepth int) { } func TestATXFP(t *testing.T) { - // t.Skip("slow test") + t.Skip("slow test") for maxDepth := 15; maxDepth <= 23; maxDepth++ { for i 
:= 0; i < 3; i++ { testATXFP(t, maxDepth) From 3df8c38e8d3a32cbfc24a29222c1e583d5580e67 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 16 Jul 2024 00:50:53 +0400 Subject: [PATCH 42/76] make dbiter wrap around --- sync2/dbsync/dbitemstore.go | 4 +- sync2/dbsync/dbiter.go | 209 ++++++++---------- sync2/dbsync/dbiter_test.go | 420 ++++++++++++++++++++---------------- 3 files changed, 322 insertions(+), 311 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 000e5cdee2..6a4d54fe80 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -40,7 +40,8 @@ func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { } func (d *DBItemStore) iter(min, max KeyBytes) (hashsync.Iterator, error) { - return newDBRangeIterator(d.db, d.query, min, max, d.chunkSize) + panic("TBD") + // return newDBRangeIterator(d.db, d.query, min, max, d.chunkSize) } // GetRangeInfo implements hashsync.ItemStore. @@ -56,6 +57,7 @@ func (d *DBItemStore) GetRangeInfo( // Min implements hashsync.ItemStore. func (d *DBItemStore) Min() (hashsync.Iterator, error) { + // INCORRECT !!! 
should return nil if the store is empty it1 := make(KeyBytes, d.keyLen) it2 := make(KeyBytes, d.keyLen) for i := range it2 { diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index 359e43f1ba..89af542ab4 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -3,6 +3,7 @@ package dbsync import ( "bytes" "errors" + "slices" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" @@ -12,47 +13,77 @@ type KeyBytes []byte var _ hashsync.Ordered = KeyBytes(nil) +func (k KeyBytes) Clone() KeyBytes { + return slices.Clone(k) +} + func (k KeyBytes) Compare(other any) int { return bytes.Compare(k, other.(KeyBytes)) } +func (k KeyBytes) inc() (overflow bool) { + for i := len(k) - 1; i >= 0; i-- { + k[i]++ + if k[i] != 0 { + return false + } + } + + return true +} + +func (k KeyBytes) zero() { + for i := range k { + k[i] = 0 + } +} + +func (k KeyBytes) isZero() bool { + for _, b := range k { + if b != 0 { + return false + } + } + return true +} + +var errEmptySet = errors.New("empty range") + type dbRangeIterator struct { - db sql.Database - from, to KeyBytes - query string - chunkSize int - chunk []KeyBytes - pos int - keyLen int + db sql.Database + from KeyBytes + query string + chunkSize int + chunk []KeyBytes + pos int + keyLen int + singleChunk bool } var _ hashsync.Iterator = &dbRangeIterator{} // makeDBIterator creates a dbRangeIterator and initializes it from the database. -// Note that [from, to] range is inclusive. +// If query returns no rows even after starting from zero ID, errEmptySet error is returned. 
func newDBRangeIterator( db sql.Database, query string, - from, to KeyBytes, + from KeyBytes, chunkSize int, ) (hashsync.Iterator, error) { if from == nil { panic("BUG: makeDBIterator: nil from") } - if to == nil { - panic("BUG: makeDBIterator: nil to") - } if chunkSize <= 0 { panic("BUG: makeDBIterator: chunkSize must be > 0") } it := &dbRangeIterator{ - db: db, - from: from, - to: to, - query: query, - chunkSize: chunkSize, - keyLen: len(from), - chunk: make([]KeyBytes, chunkSize), + db: db, + from: from.Clone(), + query: query, + chunkSize: chunkSize, + keyLen: len(from), + chunk: make([]KeyBytes, chunkSize), + singleChunk: false, } if err := it.load(); err != nil { return nil, err @@ -61,13 +92,22 @@ func newDBRangeIterator( } func (it *dbRangeIterator) load() error { + it.pos = 0 + if it.singleChunk { + // we have a single-chunk DB iterator, don't need to reload, + // just wrap around + return nil + } + n := 0 + // if the chunk size was reduced due to a short chunk before wraparound, we need + // to extend it back + it.chunk = it.chunk[:it.chunkSize] var ierr error _, err := it.db.Exec( it.query, func(stmt *sql.Statement) { stmt.BindBytes(1, it.from) - stmt.BindBytes(2, it.to) - stmt.BindInt64(3, int64(it.chunkSize)) + stmt.BindInt64(2, int64(it.chunkSize)) }, func(stmt *sql.Statement) bool { if n >= len(it.chunk) { @@ -84,19 +124,36 @@ func (it *dbRangeIterator) load() error { n++ return true }) - if err != nil || ierr != nil { + fromZero := it.from.isZero() + switch { + case err != nil || ierr != nil: return errors.Join(ierr, err) - } - it.pos = 0 - if n < len(it.chunk) { - // short chunk means there are no more data - it.from = nil + case n == 0: + // empty chunk + if fromZero { + // already wrapped around or started from 0, + // the set is empty + return errEmptySet + } + // wrap around + it.from.zero() + return it.load() + case n < len(it.chunk): + // short chunk means there are no more items after it, + // start the next chunk from 0 + it.from.zero() 
it.chunk = it.chunk[:n] - } else { + // wrapping around on an incomplete chunk that started + // from 0 means we have just a single chunk + it.singleChunk = fromZero + default: + // use last item incremented by 1 as the start of the next chunk copy(it.from, it.chunk[n-1]) - if incID(it.from) || bytes.Compare(it.from, it.to) >= 0 { - // no more items after this full chunk - it.from = nil + // inc may wrap around if it's 0xffff...fff, but it's fine + if it.from.inc() { + // if we wrapped around and the current chunk started from 0, + // we have just a single chunk + it.singleChunk = fromZero } } return nil @@ -116,96 +173,8 @@ func (it *dbRangeIterator) Next() error { return nil } it.pos++ - if it.pos < len(it.chunk) || it.from == nil { + if it.pos < len(it.chunk) { return nil } return it.load() } - -func incID(id []byte) (overflow bool) { - for i := len(id) - 1; i >= 0; i-- { - id[i]++ - if id[i] != 0 { - return false - } - } - - return true -} - -type concatIterator struct { - iters []hashsync.Iterator -} - -var _ hashsync.Iterator = &concatIterator{} - -// concatIterators concatenates multiple iterators into one. -// It assumes that the iterators follow one after another in the order of their keys. -func concatIterators(iters ...hashsync.Iterator) hashsync.Iterator { - return &concatIterator{iters: iters} -} - -func (c *concatIterator) Key() hashsync.Ordered { - if len(c.iters) == 0 { - return nil - } - return c.iters[0].Key() -} - -func (c *concatIterator) Next() error { - if len(c.iters) == 0 { - return nil - } - if err := c.iters[0].Next(); err != nil { - return err - } - for len(c.iters) > 0 { - if c.iters[0].Key() != nil { - break - } - c.iters = c.iters[1:] - } - return nil -} - -type combinedIterator struct { - iters []hashsync.Iterator - ahead hashsync.Iterator -} - -// combineIterators combines multiple iterators into one. -// Unlike concatIterator, it does not assume that the iterators follow one after another -// in the order of their keys. 
Instead, it always returns the smallest key among all -// iterators. -func combineIterators(iters ...hashsync.Iterator) hashsync.Iterator { - return &combinedIterator{iters: iters} -} - -func (c *combinedIterator) aheadIterator() hashsync.Iterator { - if c.ahead == nil { - if len(c.iters) == 0 { - return nil - } - c.ahead = c.iters[0] - for i := 1; i < len(c.iters); i++ { - if c.iters[i].Key() != nil { - if c.ahead.Key() == nil || c.iters[i].Key().Compare(c.ahead.Key()) < 0 { - c.ahead = c.iters[i] - } - } - } - } - return c.ahead -} - -func (c *combinedIterator) Key() hashsync.Ordered { - return c.aheadIterator().Key() -} - -func (c *combinedIterator) Next() error { - if err := c.aheadIterator().Next(); err != nil { - return err - } - c.ahead = nil - return nil -} diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 7b37ecea13..74d998f52c 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -2,12 +2,11 @@ package dbsync import ( "encoding/hex" - "errors" "fmt" + "slices" "testing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" "github.com/stretchr/testify/require" ) @@ -34,18 +33,22 @@ func TestIncID(t *testing.T) { } { id := make(KeyBytes, len(tc.id)) copy(id, tc.id) - require.Equal(t, tc.overflow, incID(id)) + require.Equal(t, tc.overflow, id.inc()) require.Equal(t, tc.expected, id) } } -func populateDB(t *testing.T, keyLen int, content []KeyBytes) sql.Database { +func createDB(t *testing.T, keyLen int) sql.Database { db := sql.InMemory(sql.WithIgnoreSchemaDrift()) t.Cleanup(func() { require.NoError(t, db.Close()) }) _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d))", keyLen), nil, nil) require.NoError(t, err) + return db +} + +func insertDBItems(t *testing.T, db sql.Database, content []KeyBytes) { for _, id := range content { _, err := db.Exec( "insert into foo(id) values(?)", @@ -54,223 +57,260 @@ func populateDB(t *testing.T, keyLen int, content 
[]KeyBytes) sql.Database { }, nil) require.NoError(t, err) } +} + +func deleteDBItems(t *testing.T, db sql.Database) { + _, err := db.Exec("delete from foo", nil, nil) + require.NoError(t, err) +} + +func populateDB(t *testing.T, keyLen int, content []KeyBytes) sql.Database { + db := createDB(t, keyLen) + insertDBItems(t, db, content) return db } -const testQuery = "select id from foo where id between ? and ? order by id limit ?" +const testQuery = "select id from foo where id >= ? order by id limit ?" func TestDBRangeIterator(t *testing.T) { - db := populateDB(t, 4, []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x00, 0x00, 0x00, 0x05}, - {0x00, 0x00, 0x00, 0x07}, - {0x00, 0x00, 0x01, 0x00}, - {0x00, 0x00, 0x03, 0x00}, - {0x00, 0x01, 0x00, 0x00}, - {0x00, 0x05, 0x00, 0x00}, - {0x03, 0x05, 0x00, 0x00}, - {0x09, 0x05, 0x00, 0x00}, - {0x0a, 0x05, 0x00, 0x00}, - {0xff, 0xff, 0xff, 0xff}, - }) + db := createDB(t, 4) for _, tc := range []struct { - from, to KeyBytes - chunkSize int - items []KeyBytes + items []KeyBytes + from KeyBytes + fromN int + expErr error }{ { - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - to: KeyBytes{0x00, 0x00, 0x00, 0x00}, - chunkSize: 4, - items: nil, + items: nil, + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + expErr: errEmptySet, + }, + { + items: nil, + from: KeyBytes{0x80, 0x00, 0x00, 0x00}, + expErr: errEmptySet, + }, + { + items: nil, + from: KeyBytes{0xff, 0xff, 0xff, 0xff}, + expErr: errEmptySet, + }, + { + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, 
+ from: KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, }, { - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - to: KeyBytes{0x00, 0x00, 0x00, 0x08}, - chunkSize: 4, items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x00, 0x00, 0x00, 0x05}, - {0x00, 0x00, 0x00, 0x07}, + {0x01, 0x02, 0x03, 0x04}, }, + from: KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, }, { - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - to: KeyBytes{0x00, 0x00, 0x03, 0x00}, - chunkSize: 4, items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x00, 0x00, 0x00, 0x05}, - {0x00, 0x00, 0x00, 0x07}, - {0x00, 0x00, 0x01, 0x00}, - {0x00, 0x00, 0x03, 0x00}, + {0xff, 0xff, 0xff, 0xff}, }, + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, }, { - from: KeyBytes{0x00, 0x00, 0x03, 0x00}, - to: KeyBytes{0x09, 0x05, 0x00, 0x00}, - chunkSize: 4, items: []KeyBytes{ - {0x00, 0x00, 0x03, 0x00}, - {0x00, 0x01, 0x00, 0x00}, - {0x00, 0x05, 0x00, 0x00}, - {0x03, 0x05, 0x00, 0x00}, - {0x09, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, }, + from: KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, }, { - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - to: KeyBytes{0xff, 0xff, 0xff, 0xff}, - chunkSize: 4, items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x00, 0x00, 0x00, 0x05}, - {0x00, 0x00, 0x00, 0x07}, - {0x00, 0x00, 0x01, 0x00}, - {0x00, 0x00, 0x03, 0x00}, - {0x00, 0x01, 0x00, 0x00}, - {0x00, 0x05, 0x00, 0x00}, - {0x03, 0x05, 0x00, 0x00}, - {0x09, 0x05, 0x00, 0x00}, - {0x0a, 0x05, 0x00, 0x00}, {0xff, 0xff, 0xff, 0xff}, }, + from: KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: 
KeyBytes{0x00, 0x00, 0x00, 0x01}, + fromN: 0, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x02}, + fromN: 1, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x03}, + fromN: 1, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x05}, + fromN: 2, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: KeyBytes{0x00, 0x00, 0x00, 0x07}, + fromN: 3, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + 4: {0x00, 0x00, 0x01, 0x00}, + 5: {0x00, 0x00, 0x03, 0x00}, + 6: {0x00, 0x01, 0x00, 0x00}, + 7: {0x00, 0x05, 0x00, 0x00}, + 8: {0x03, 0x05, 0x00, 0x00}, + 9: {0x09, 0x05, 0x00, 0x00}, + 10: {0x0a, 0x05, 0x00, 0x00}, + 11: {0xff, 0xff, 0xff, 0xff}, + }, + from: KeyBytes{0x00, 0x00, 0x03, 0x01}, + fromN: 6, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + 4: {0x00, 0x00, 0x01, 0x00}, + 5: {0x00, 0x00, 0x03, 0x00}, + 6: {0x00, 0x01, 0x00, 0x00}, + 7: {0x00, 0x05, 0x00, 0x00}, + 8: {0x03, 0x05, 0x00, 0x00}, + 9: {0x09, 0x05, 0x00, 0x00}, + 10: {0x0a, 0x05, 0x00, 0x00}, + 11: {0xff, 0xff, 0xff, 0xff}, + }, + from: 
KeyBytes{0x00, 0x01, 0x00, 0x00}, + fromN: 6, + }, + { + items: []KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + 4: {0x00, 0x00, 0x01, 0x00}, + 5: {0x00, 0x00, 0x03, 0x00}, + 6: {0x00, 0x01, 0x00, 0x00}, + 7: {0x00, 0x05, 0x00, 0x00}, + 8: {0x03, 0x05, 0x00, 0x00}, + 9: {0x09, 0x05, 0x00, 0x00}, + 10: {0x0a, 0x05, 0x00, 0x00}, + 11: {0xff, 0xff, 0xff, 0xff}, + }, + from: KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 11, }, } { - it, err := newDBRangeIterator(db, testQuery, tc.from, tc.to, tc.chunkSize) - require.NoError(t, err) - if len(tc.items) == 0 { - require.Nil(t, it.Key()) - } else { + deleteDBItems(t, db) + insertDBItems(t, db, tc.items) + for chunkSize := 1; chunkSize < 12; chunkSize++ { + it, err := newDBRangeIterator(db, testQuery, tc.from, chunkSize) + if tc.expErr != nil { + require.ErrorIs(t, err, tc.expErr) + continue + } + require.NoError(t, err) + // when there are no items, errEmptySet is returned + require.NotEmpty(t, tc.items) var collected []KeyBytes for i := 0; i < len(tc.items); i++ { - if k := it.Key(); k != nil { - collected = append(collected, k.(KeyBytes)) - } else { - break - } + k := it.Key() + require.NotNil(t, k) + collected = append(collected, k.(KeyBytes)) require.NoError(t, it.Next()) } - require.Nil(t, it.Key()) - require.Equal(t, tc.items, collected, "from=%s to=%s chunkSize=%d", - hex.EncodeToString(tc.from), hex.EncodeToString(tc.to), tc.chunkSize) + expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) + require.Equal(t, expected, collected, "count=%d from=%s chunkSize=%d", + len(tc.items), hex.EncodeToString(tc.from), chunkSize) + for range 2 { + for i := 0; i < len(tc.items); i++ { + k := it.Key() + require.Equal(t, collected[i], k.(KeyBytes)) + require.NoError(t, it.Next()) + } + } } } } - -type fakeIterator struct { - items []KeyBytes -} - -var _ hashsync.Iterator = &fakeIterator{} - -func (it *fakeIterator) Key() 
hashsync.Ordered { - if len(it.items) == 0 { - return nil - } - return KeyBytes(it.items[0]) -} - -func (it *fakeIterator) Next() error { - if len(it.items) != 0 { - it.items = it.items[1:] - } - if len(it.items) != 0 && string(it.items[0]) == "error" { - return errors.New("iterator error") - } - return nil -} - -func TestConcatIterators(t *testing.T) { - it1 := &fakeIterator{ - items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - }, - } - it2 := &fakeIterator{ - items: []KeyBytes{ - {0x0a, 0x05, 0x00, 0x00}, - {0xff, 0xff, 0xff, 0xff}, - }, - } - - it := concatIterators(it1, it2) - var collected []KeyBytes - for i := 0; i < 4; i++ { - collected = append(collected, it.Key().(KeyBytes)) - require.NoError(t, it.Next()) - } - require.Nil(t, it.Key()) - require.Equal(t, []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x0a, 0x05, 0x00, 0x00}, - {0xff, 0xff, 0xff, 0xff}, - }, collected) - - it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} - it2 = &fakeIterator{items: nil} - - it = concatIterators(it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) - require.Error(t, it.Next()) - - it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}}} - it2 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} - - it = concatIterators(it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) - require.NoError(t, it.Next()) - require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) - require.Error(t, it.Next()) -} - -func TestCombineIterators(t *testing.T) { - it1 := &fakeIterator{ - items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x0a, 0x05, 0x00, 0x00}, - }, - } - it2 := &fakeIterator{ - items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x03}, - {0xff, 0xff, 0xff, 0xff}, - }, - } - - it := combineIterators(it1, it2) - var collected []KeyBytes - for i := 0; i < 4; i++ { - collected = append(collected, it.Key().(KeyBytes)) - require.NoError(t, it.Next()) - } - require.Nil(t, it.Key()) - 
require.Equal(t, []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x0a, 0x05, 0x00, 0x00}, - {0xff, 0xff, 0xff, 0xff}, - }, collected) - - it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} - it2 = &fakeIterator{items: nil} - - it = combineIterators(it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) - require.Error(t, it.Next()) - - it1 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 0}}} - it2 = &fakeIterator{items: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} - - it = combineIterators(it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) - require.NoError(t, it.Next()) - require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) - require.Error(t, it.Next()) -} From b5b00ea4726348cbf31d55b321b2b53af39a90ea Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 17 Jul 2024 05:29:15 +0400 Subject: [PATCH 43/76] use dynamically extended chunkSize in dbiter --- sync2/dbsync/dbiter.go | 43 +++++++++++++++++++++---------------- sync2/dbsync/dbiter_test.go | 8 +++---- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index 89af542ab4..f3754d1f37 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -50,14 +50,15 @@ func (k KeyBytes) isZero() bool { var errEmptySet = errors.New("empty range") type dbRangeIterator struct { - db sql.Database - from KeyBytes - query string - chunkSize int - chunk []KeyBytes - pos int - keyLen int - singleChunk bool + db sql.Database + from KeyBytes + query string + chunkSize int + maxChunkSize int + chunk []KeyBytes + pos int + keyLen int + singleChunk bool } var _ hashsync.Iterator = &dbRangeIterator{} @@ -68,22 +69,23 @@ func newDBRangeIterator( db sql.Database, query string, from KeyBytes, - chunkSize int, + maxChunkSize int, ) (hashsync.Iterator, error) { if from == nil { panic("BUG: makeDBIterator: nil from") } - if chunkSize <= 0 { + if maxChunkSize <= 0 { panic("BUG: makeDBIterator: 
chunkSize must be > 0") } it := &dbRangeIterator{ - db: db, - from: from.Clone(), - query: query, - chunkSize: chunkSize, - keyLen: len(from), - chunk: make([]KeyBytes, chunkSize), - singleChunk: false, + db: db, + from: from.Clone(), + query: query, + chunkSize: 1, + maxChunkSize: maxChunkSize, + keyLen: len(from), + chunk: make([]KeyBytes, maxChunkSize), + singleChunk: false, } if err := it.load(); err != nil { return nil, err @@ -102,7 +104,11 @@ func (it *dbRangeIterator) load() error { n := 0 // if the chunk size was reduced due to a short chunk before wraparound, we need // to extend it back - it.chunk = it.chunk[:it.chunkSize] + if cap(it.chunk) < it.chunkSize { + it.chunk = make([]KeyBytes, it.chunkSize) + } else { + it.chunk = it.chunk[:it.chunkSize] + } var ierr error _, err := it.db.Exec( it.query, func(stmt *sql.Statement) { @@ -125,6 +131,7 @@ func (it *dbRangeIterator) load() error { return true }) fromZero := it.from.isZero() + it.chunkSize = min(it.chunkSize*2, it.maxChunkSize) switch { case err != nil || ierr != nil: return errors.Join(ierr, err) diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 74d998f52c..aa71dfc478 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -285,8 +285,8 @@ func TestDBRangeIterator(t *testing.T) { } { deleteDBItems(t, db) insertDBItems(t, db, tc.items) - for chunkSize := 1; chunkSize < 12; chunkSize++ { - it, err := newDBRangeIterator(db, testQuery, tc.from, chunkSize) + for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { + it, err := newDBRangeIterator(db, testQuery, tc.from, maxChunkSize) if tc.expErr != nil { require.ErrorIs(t, err, tc.expErr) continue @@ -302,8 +302,8 @@ func TestDBRangeIterator(t *testing.T) { require.NoError(t, it.Next()) } expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) - require.Equal(t, expected, collected, "count=%d from=%s chunkSize=%d", - len(tc.items), hex.EncodeToString(tc.from), chunkSize) + require.Equal(t, 
expected, collected, "count=%d from=%s maxChunkSize=%d", + len(tc.items), hex.EncodeToString(tc.from), maxChunkSize) for range 2 { for i := 0; i < len(tc.items); i++ { k := it.Key() From bbccaf085c0d2c00fae628c6aca26697fafb35bc Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 17 Jul 2024 05:32:21 +0400 Subject: [PATCH 44/76] fptree: use startTail/endTail instead of tailRefs list --- sync2/dbsync/fptree.go | 280 +++++++++++++++++++++++++----------- sync2/dbsync/fptree_test.go | 175 ++++++++++++++++++---- 2 files changed, 343 insertions(+), 112 deletions(-) diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index f5404639c6..9bd8b2934e 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -288,8 +288,6 @@ type fpResult struct { } type tailRef struct { - // node from which this tailRef has been derived - idx nodeIndex // maxDepth bits of the key ref uint64 // max count to get from this tail ref, -1 for unlimited @@ -297,13 +295,68 @@ type tailRef struct { } type aggResult struct { - tailRefs []tailRef - fp fingerprint - count uint32 - itype int - limit int - lastVisited nodeIndex - lastPrefix prefix + startTail, endTail *tailRef + fp fingerprint + count uint32 + itype int + limit int + lastVisited nodeIndex + lastPrefix prefix +} + +func (r *aggResult) setStartTail(tail *tailRef) { + if r.startTail != nil { + panic("BUG: left tail already set") + } + r.startTail = tail +} + +func (r *aggResult) setEndTail(tail *tailRef) { + if r.endTail != nil { + panic("BUG: right tail already set") + } + r.endTail = tail +} + +func (r *aggResult) setTails(tail *tailRef, x, y KeyBytes, maxDepth int) { + if r.itype == 0 { + // Doesn't matter which tail as the IDs aren't filtered based on + // the interval, only based on the limit + r.setStartTail(tail) + return + } + xRef := load64(x) >> (64 - maxDepth) + yRef := load64(y) >> (64 - maxDepth) + switch { + case xRef == yRef: + // Same ref for x and y. 
+ // In this case, this tail may only contain relevant entries + // if it's ref is the same as xRef and yRef, and it needs + // to be used for both tails. + if tail.ref == xRef { + r.setStartTail(tail) + r.setEndTail(tail) + } + case r.itype < 0: + // Normal interval. + // The tail can cover the start in case if it's at or below xRef, + // and the end in case if it's at or below yRef, but after xRef. + if tail.ref <= xRef { + r.setStartTail(tail) + } else if tail.ref <= yRef { + r.setEndTail(tail) + } + default: + // Inverse interval. + // The tail can cover the start in case if it's at or below xRef, + // but also after yRef. + // It can cover the end in case if it's at or below yRef. + if tail.ref <= yRef { + r.setEndTail(tail) + } else if tail.ref <= xRef { + r.setStartTail(tail) + } + } } func (r *aggResult) takeAtMost(count int) int { @@ -498,16 +551,16 @@ func (ft *fpTree) tailRefFromPrefix(idx nodeIndex, p prefix, limit int) tailRef if p.len() != ft.maxDepth { panic("BUG: tail from short prefix") } - return tailRef{idx: idx, ref: p.bits(), limit: limit} + return tailRef{ref: p.bits(), limit: limit} } func (ft *fpTree) tailRefFromFingerprint(idx nodeIndex, fp fingerprint, limit int) tailRef { v := load64(fp[:]) if ft.maxDepth >= 64 { - return tailRef{idx: idx, ref: v, limit: limit} + return tailRef{ref: v, limit: limit} } // // fmt.Fprintf(os.Stderr, "QQQQQ: AAAAA: v %016x maxDepth %d shift %d\n", v, ft.maxDepth, (64 - ft.maxDepth)) - return tailRef{idx: idx, ref: v >> (64 - ft.maxDepth), limit: limit} + return tailRef{ref: v >> (64 - ft.maxDepth), limit: limit} } func (ft *fpTree) tailRefFromNodeAndPrefix(idx nodeIndex, n node, p prefix, limit int) tailRef { @@ -569,70 +622,89 @@ func (ft *fpTree) visitNode(idx nodeIndex, p prefix, r *aggResult) (node, bool) return ft.np.node(idx), true } -func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, r *aggResult) { +func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, r *aggResult) (tailRef 
*tailRef, cont bool) { ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, r.limit, r.fp.String(), r.count) defer func() { ft.leave(r.fp, r.count) }() - node, ok := ft.visitNode(idx, p, r) - switch { - case !ok || r.limit == 0: - // for r.limit == 0, it's important that we still visit the node - // so that we can get the item immediately following the included items - ft.log("stop: ok %v r.limit %d", ok, r.limit) - case r.limit < 0: - // no limit - ft.log("no limit") - r.update(node) - case node.c <= uint32(r.limit): - // node is fully included - ft.log("included fully") - r.update(node) - r.limit -= int(node.c) - case node.leaf(): - tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) - r.tailRefs = append(r.tailRefs, tail) - ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) - default: - pLeft := p.left() - left, haveLeft := ft.visitNode(node.left, pLeft, r) - if haveLeft { - if int(left.c) <= r.limit { - // left node is fully included - ft.log("include left in full") - r.update(left) - r.limit -= int(left.c) - } else { - // we must stop somewhere in the left subtree - ft.log("descend to the left") - ft.aggregateUpToLimit(node.left, pLeft, r) - return + for { + node, ok := ft.visitNode(idx, p, r) + switch { + case !ok: + ft.log("stop: no node") + return nil, true + case r.limit == 0: + // for r.limit == 0, it's important that we still visit the node + // so that we can get the item immediately following the included items + ft.log("stop: limit exhausted") + return nil, false + case r.limit < 0: + // no limit + ft.log("no limit") + r.update(node) + return nil, true + case node.c <= uint32(r.limit): + // node is fully included + ft.log("included fully") + r.update(node) + r.limit -= int(node.c) + return nil, true + case node.leaf(): + // reached the limit on this node, do not need to continue after + // done with it + tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) + ft.log("add prefix 
to the tails: %016x => limit %d", tail.ref, r.limit) + return &tail, false + default: + pLeft := p.left() + left, haveLeft := ft.visitNode(node.left, pLeft, r) + if haveLeft { + if int(left.c) <= r.limit { + // left node is fully included, after which + // we need to stop somewhere in the right subtree + ft.log("include left in full") + r.update(left) + r.limit -= int(left.c) + } else { + // we must stop somewhere in the left subtree, + // and the right subtree is irrelevant + ft.log("descend to the left") + idx = node.left + p = pLeft + continue + } } + ft.log("descend to the right") + idx = node.right + p = p.right() } - ft.log("descend to the right") - ft.aggregateUpToLimit(node.right, p.right(), r) } } -func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) { +func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) (cont bool) { ft.enter("aggregateLeft: idx %d v %016x p %s limit %d", idx, v, p, r.limit) defer func() { - ft.leave(r.fp, r.count, r.tailRefs) + ft.leave(r.fp, r.count, r.startTail, r.endTail) }() node, ok := ft.visitNode(idx, p, r) switch { - case !ok || r.limit == 0: + case !ok: // for r.limit == 0, it's important that we still visit the node // so that we can get the item immediately following the included items - ft.log("stop: ok %v r.limit %d", ok, r.limit) + ft.log("stop: no node") + return true + case r.limit == 0: + ft.log("stop: limit exhausted") + return false case p.len() == ft.maxDepth: if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) - r.tailRefs = append(r.tailRefs, tail) + r.setStartTail(&tail) ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) + return true case node.leaf(): // TBD: combine with prev // For leaf 1-nodes, we can use the fingerprint to get tailRef // by which the actual IDs will be selected @@ -640,38 +712,51 @@ func (ft *fpTree) 
aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) panic("BUG: leaf non-1 node below maxDepth") } tail := ft.tailRefFromFingerprint(idx, node.fp, r.takeAtMost(1)) - r.tailRefs = append(r.tailRefs, tail) + r.setStartTail(&tail) ft.log("add prefix to the tails (1-leaf): %016x (fp %s) => limit %d", tail.ref, node.fp, r.limit) + return true case v&bit63 == 0: ft.log("incl right node %d + go left to node %d", node.right, node.left) + if !ft.aggregateLeft(node.left, v<<1, p.left(), r) { + return false + } if node.right != noIndex { - ft.aggregateUpToLimit(node.right, p.right(), r) + tail, cont := ft.aggregateUpToLimit(node.right, p.right(), r) + if tail != nil { + r.setStartTail(tail) + } + return cont } - ft.aggregateLeft(node.left, v<<1, p.left(), r) + return true default: ft.log("go right to node %d", node.right) - ft.aggregateLeft(node.right, v<<1, p.right(), r) + return ft.aggregateLeft(node.right, v<<1, p.right(), r) } } -func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r *aggResult) { +func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r *aggResult) (cont bool) { ft.enter("aggregateRight: idx %d v %016x p %s limit %d", idx, v, p, r.limit) defer func() { - ft.leave(r.fp, r.count, r.tailRefs) + ft.leave(r.fp, r.count, r.startTail, r.endTail) }() node, ok := ft.visitNode(idx, p, r) switch { - case !ok || r.limit == 0: + case !ok: // for r.limit == 0, it's important that we still visit the node // so that we can get the item immediately following the included items - ft.log("stop: ok %v r.limit %d", ok, r.limit) + ft.log("stop: no node") + return true + case r.limit == 0: + ft.log("stop: limit exhausted") + return false case p.len() == ft.maxDepth: if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) - r.tailRefs = append(r.tailRefs, tail) + r.setEndTail(&tail) ft.log("add prefix to the tails: %016x => limit %d", 
tail.ref, r.limit) + return true case node.leaf(): // For leaf 1-nodes, we can use the fingerprint to get tailRef // by which the actual IDs will be selected @@ -679,17 +764,24 @@ func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r *aggResult panic("BUG: leaf non-1 node below maxDepth") } tail := ft.tailRefFromFingerprint(idx, node.fp, r.takeAtMost(1)) - r.tailRefs = append(r.tailRefs, tail) + r.setEndTail(&tail) ft.log("add prefix to the tails (1-leaf): %016x (fp %s) => limit %d", tail.ref, node.fp, r.limit) + return true case v&bit63 == 0: ft.log("go left to node %d", node.left) - ft.aggregateRight(node.left, v<<1, p.left(), r) + return ft.aggregateRight(node.left, v<<1, p.left(), r) default: ft.log("incl left node %d + go right to node %d", node.left, node.right) if node.left != noIndex { - ft.aggregateUpToLimit(node.left, p.left(), r) + tail, cont := ft.aggregateUpToLimit(node.left, p.left(), r) + if tail != nil { + r.setEndTail(tail) + } + if !cont { + return false + } } - ft.aggregateRight(node.right, v<<1, p.right(), r) + return ft.aggregateRight(node.right, v<<1, p.right(), r) } } @@ -707,7 +799,8 @@ func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { // the whole set if ft.root != noIndex { ft.log("whole set") - ft.aggregateUpToLimit(ft.root, 0, &r) + tail, _ := ft.aggregateUpToLimit(ft.root, 0, &r) + r.setTails(tail, x, y, ft.maxDepth) } else { ft.log("empty set (no root)") } @@ -735,13 +828,14 @@ func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { // through the IDs if lcaNode.leaf() { ft.visitNode(lcaIdx, followedPrefix, &r) - r.tailRefs = append(r.tailRefs, - ft.tailRefFromNodeAndPrefix( - lcaIdx, lcaNode, followedPrefix, r.takeAtMost(limit))) + tail := ft.tailRefFromNodeAndPrefix( + lcaIdx, lcaNode, followedPrefix, r.takeAtMost(limit)) + r.setTails(&tail, x, y, ft.maxDepth) } } default: // inverse interval: [min; y); [x; max] + // first, we handle [x; max] part pf0 := preFirst0(x) 
idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, 0) var pf0Node node @@ -759,10 +853,12 @@ func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { if pf0Node.leaf() { ft.visitNode(idx0, followedPrefix, &r) rightLimit := r.takeAtMost(int(pf0Node.c)) - r.tailRefs = append(r.tailRefs, ft.tailRefFromNodeAndPrefix(idx0, pf0Node, followedPrefix, rightLimit)) + tail := ft.tailRefFromNodeAndPrefix(idx0, pf0Node, followedPrefix, rightLimit) + r.setStartTail(&tail) } } + // then we handle [min, y) part pf1 := preFirst1(y) idx1, followedPrefix, found := ft.followPrefix(ft.root, pf1, 0) var pf1Node node @@ -780,7 +876,8 @@ func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { if pf1Node.leaf() { ft.visitNode(idx1, followedPrefix, &r) leftLimit := r.takeAtMost(int(pf1Node.c)) - r.tailRefs = append(r.tailRefs, ft.tailRefFromNodeAndPrefix(idx1, pf1Node, followedPrefix, leftLimit)) + tail := ft.tailRefFromNodeAndPrefix(idx1, pf1Node, followedPrefix, leftLimit) + r.setEndTail(&tail) } } } @@ -794,17 +891,30 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e }() r := ft.aggregateInterval(x, y, limit) wasWithinRange := false - ft.log("tailRefs: %#v count: %d", r.tailRefs, r.count) - // Check for edge case: the fingerprinting has looped back to the same tail, - // so we have the tail repeated twice and no other tails. - // We can't hit any tails in between. 
- noStop := false - if len(r.tailRefs) == 2 && r.tailRefs[0].ref == r.tailRefs[1].ref { - ft.log("edge case: tailRef loopback") - r.tailRefs = r.tailRefs[:1] - noStop = true - } - if err := ft.idStore.iterateIDs(r.tailRefs, func(tailRef tailRef, id KeyBytes) bool { + ft.log("startTail: %#v endTail: %#v count: %d", r.startTail, r.endTail, r.count) + // QQQQQ: TBD: scan tails separately using the iterators + var tailRefs []tailRef + if r.startTail != nil { + tailRefs = append(tailRefs, *r.startTail) + } + if r.endTail != nil && (r.startTail == nil || r.startTail.ref != r.endTail.ref) { + tailRefs = append(tailRefs, *r.endTail) + } + noStop := true //len(tailRefs) == 1 + // panic("TBD: limit; problem: single-tailRef wraparound, need to start from x in any case, but we're starting from the beginning of the tailRef") + // QQQQQ: should not use iterateIDs. No need for nextLeaf, etc. + // The store should return iterator for each tailRef + // There should be *always* two tailRefs (they may be the same) + // For 1st tailRef we call Next() on the iterator till Key() is >= x, it'll be RangeInfo.Start, + // then we clone the iterator (it should be cloneable!) 
and use it to include the IDs + // in the fingerprint + // For 2nd tailRef we call Next() on the iterator till Key() is >= y, including IDs + // in the fingerprint, then return it as RangeInfo.End (outside the range) + // DOWNSIDE: too much needs to be fetched in case of bigger chunks + // Possible optimization [DONE]: specify max chunk size for db iterator, + // start from 1, increase it 2x on each iteration but not over max chunk size + // QQQQQ: TBD: need to restore combinedIterator + if err := ft.idStore.iterateIDs(tailRefs, func(tailRef tailRef, id KeyBytes) bool { if idWithinInterval(id, x, y, r.itype) { r.fp.update(id) r.count++ @@ -816,7 +926,7 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e } else { // if we were within the range but now we're out of it, // this means we're at or beyond y and can stop - // return !wasWithinRange + // return !wasWithinRange || noStop // QQQQQ: rmme if wasWithinRange { ft.log("tailRef %v: id %s outside range after id(s) within range => terminating", @@ -847,7 +957,7 @@ func (ft *fpTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) { leaf := node.leaf() countStr := strconv.Itoa(int(node.c)) if leaf { - countStr = "LEAF-" + countStr + countStr = "LEAF:" + countStr } fmt.Fprintf(w, "%s%sidx=%d %s %s [%d]\n", indent, dir, idx, node.fp, countStr, ft.np.refCount(idx)) if !leaf { diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index d864e6f980..e8b1fe5f06 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -215,11 +215,20 @@ func (s *fakeIDDBStore) registerHash(h KeyBytes) error { type idStoreFunc func(maxDepth int) idStore func testFPTree(t *testing.T, makeIDStore idStoreFunc) { + type rangeTestCase struct { + xIdx, yIdx int + x, y string + limit int + fp fingerprint + count uint32 + itype int + } for _, tc := range []struct { name string maxDepth int ids []string - results map[[3]int]fpResult + ranges []rangeTestCase + x, y string }{ 
{ name: "ids1", @@ -231,63 +240,99 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { "8888888888888888888888888888888888888888888888888888888888888888", "abcdef1234567890000000000000000000000000000000000000000000000000", }, - results: map[[3]int]fpResult{ - {0, 0, -1}: { + ranges: []rangeTestCase{ + { + xIdx: 0, + yIdx: 0, + limit: -1, fp: hexToFingerprint("642464b773377bbddddddddd"), count: 5, itype: 0, }, - {0, 0, 3}: { + { + xIdx: 0, + yIdx: 0, + limit: 3, fp: hexToFingerprint("4761032dcfe98ba555555555"), count: 3, itype: 0, }, - {4, 4, -1}: { + { + xIdx: 4, + yIdx: 4, + limit: -1, fp: hexToFingerprint("642464b773377bbddddddddd"), count: 5, itype: 0, }, - {0, 1, -1}: { + { + xIdx: 0, + yIdx: 1, + limit: -1, fp: hexToFingerprint("000000000000000000000000"), count: 1, itype: -1, }, - {0, 3, -1}: { + { + xIdx: 0, + yIdx: 3, + limit: -1, fp: hexToFingerprint("4761032dcfe98ba555555555"), count: 3, itype: -1, }, - {0, 4, 3}: { + { + xIdx: 0, + yIdx: 4, + limit: 3, fp: hexToFingerprint("4761032dcfe98ba555555555"), count: 3, itype: -1, }, - {1, 4, -1}: { + { + xIdx: 1, + yIdx: 4, + limit: -1, fp: hexToFingerprint("cfe98ba54761032ddddddddd"), count: 3, itype: -1, }, - {1, 0, -1}: { + { + xIdx: 1, + yIdx: 0, + limit: -1, fp: hexToFingerprint("642464b773377bbddddddddd"), count: 4, itype: 1, }, - {2, 0, -1}: { + { + xIdx: 2, + yIdx: 0, + limit: -1, fp: hexToFingerprint("761032cfe98ba54ddddddddd"), count: 3, itype: 1, }, - {3, 1, -1}: { + { + xIdx: 3, + yIdx: 1, + limit: -1, fp: hexToFingerprint("2345679abcdef01888888888"), count: 3, itype: 1, }, - {3, 2, -1}: { + { + xIdx: 3, + yIdx: 2, + limit: -1, fp: hexToFingerprint("317131e226622ee888888888"), count: 4, itype: 1, }, - {3, 2, 3}: { + { + xIdx: 3, + yIdx: 2, + limit: 3, fp: hexToFingerprint("2345679abcdef01888888888"), count: 3, itype: 1, @@ -303,28 +348,43 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { "a280bcb8123393e0d4a15e5c9850aab5dddffa03d5efa92e59bc96202e8992bc", 
"e93163f908630280c2a8bffd9930aa684be7a3085432035f5c641b0786590d1d", }, - results: map[[3]int]fpResult{ - {0, 0, -1}: { + ranges: []rangeTestCase{ + { + xIdx: 0, + yIdx: 0, + limit: -1, fp: hexToFingerprint("a76fc452775b55e0dacd8be5"), count: 4, itype: 0, }, - {0, 0, 3}: { + { + xIdx: 0, + yIdx: 0, + limit: 3, fp: hexToFingerprint("4e5ea7ab7f38576018653418"), count: 3, itype: 0, }, - {0, 3, -1}: { + { + xIdx: 0, + yIdx: 3, + limit: -1, fp: hexToFingerprint("4e5ea7ab7f38576018653418"), count: 3, itype: -1, }, - {3, 1, -1}: { + { + xIdx: 3, + yIdx: 1, + limit: -1, fp: hexToFingerprint("87760f5e21a0868dc3b0c7a9"), count: 2, itype: 1, }, - {3, 2, -1}: { + { + xIdx: 3, + yIdx: 2, + limit: -1, fp: hexToFingerprint("05ef78ea6568c6000e6cd5b9"), count: 3, itype: 1, @@ -368,14 +428,65 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { "e41d8c3a7607ec5423cc376a34d21494f2d0c625fb9bebcec09d06c188ab7f3f", "e9110a384198b47be2bb63e64f094069a0ee9a013e013176bbe8189834c5e4c8", }, - results: map[[3]int]fpResult{ - {31, 0, -1}: { + ranges: []rangeTestCase{ + { + xIdx: 31, + yIdx: 0, + limit: -1, fp: hexToFingerprint("e9110a384198b47be2bb63e6"), count: 1, itype: 1, }, }, }, + { + name: "ids4", + maxDepth: 24, + ids: []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + "33940245f4aace670c84f471ff4e862d1d82ce0ada9b98a753038b4f9e60e330", + "366d9e7adb3932e52e0a92a0afc75a2875995e7de8e0c4159e22eb97526a3547", + 
"66883aa35d2c8d293f07c5c5c40c63416317423418fe5c7fd17b5fb68b3e976e", + "80fce3e9654459cff3441e1a96413f0872e0b6f093879609696042fcfe1c8115", + "8b2025fbe0bbebea4baee48bac9a63a4013a2ec898d7b0a518eccdb99bdb368e", + "8e3e609653adfddcdcb6ddda7461db3a2fc822c3f96874a002f715b80865e575", + "9b25e39d6cc3beac3ecc12140f46a699880ac8303555c694fd40ba8e61bb8b47", + "a3c8628a1b28d1ba6f3d8beb4a29315c02789c5b53a095fa7865c9b3041502d6", + "a98fdcab5e351a1bfd25ddcf9973e9c56a4b688d78743a8a03fa3b1d53da4949", + "ac9c015dd51defacfc14bd4c9c8eedb89aad884bef493553a189a2915c828e95", + "ba745196493a8368ef091860f2692978b381f67566d3413e85167672d672c8ac", + "c26353d8bc9a1eea8e79fd693c1a1e58dacded75ceda84ed6c356bcf02b6d0f1", + "c3f126a37c2e33b6258c87fd043026dacf0b8dd4df7a9afd7cdc293b075e1878", + "cefd0cc8b32929df07b6ebb5b6e433f28d5460f143814f3f651330ea15e5d6e7", + "d9390718256e71edfe671334edbfcbed8b4de3221db55805ebf606c73fe969f1", + "db7ee147da05a5cbec3f59b020cbdba88e40ab6b212ae93c98d5a210d83a4a7b", + "deab906f979a647eff85f3a54e5edd665f2536e0005812aee2e5e411ae71855e", + "e0b6ab7f483527771faadbee8b4ed99ae96167d054ae5c513faf00c78aa36bdd", + "e4ed6f5dcf179a4f10521d58d65d423098af5f6f18c42f3125a5917d338b7477", + "e53de3ec53ba88029a2a0459a3ab82cdb3726c8aeccabf38a04e048b9add92ef", + "f2aff99498615c44d94266060e948c11bb275ec37d0d3c651bb3ba0039a11a64", + "f7f81332b63b79718f0321660a5cd8f6970474ff873afcdebb0d3436a2ad12ac", + "fb42c36089a4883bc7ceaae9a57924d78557edb63ede3d5a2cf2d1f08db799d0", + "fe494ce48f5826c00f6bc6af74258ec6e47b92365850deed95b5bfcaeccc6be8", + }, + ranges: []rangeTestCase{ + { + x: "582485793d71c3e8429b9b2c8df360c2ea7bf90080d5bf375fe4618b00f59c0b", + y: "7eff517d2f11ed32f935be3001499ac779160a4891a496f88da0ceb33e3496cc", + limit: -1, + fp: hexToFingerprint("66883aa35d2c8d293f07c5c5"), + count: 1, + itype: -1, + }, + }, + }, } { t.Run(tc.name, func(t *testing.T) { var np nodePool @@ -395,12 +506,22 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { checkTree(t, ft, tc.maxDepth) - for 
idRange, expResult := range tc.results { - x := hs[idRange[0]] - y := hs[idRange[1]] - fpr, err := ft.fingerprintInterval(x[:], y[:], idRange[2]) + for _, rtc := range tc.ranges { + var x, y types.Hash32 + if rtc.x != "" { + x = types.HexToHash32(rtc.x) + y = types.HexToHash32(rtc.y) + } else { + x = hs[rtc.xIdx] + y = hs[rtc.yIdx] + } + fpr, err := ft.fingerprintInterval(x[:], y[:], rtc.limit) require.NoError(t, err) - require.Equal(t, expResult, fpr) + require.Equal(t, fpResult{ + fp: rtc.fp, + count: rtc.count, + itype: rtc.itype, + }, fpr) } ft.release() From 825ff03dbf829bcf0e744f4a86108bf629f21947 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 23 Jul 2024 20:11:35 +0400 Subject: [PATCH 45/76] fptree: use iterators in fptree & fix limits --- sync2/dbsync/dbitemstore.go | 21 +- sync2/dbsync/dbiter.go | 91 ++- sync2/dbsync/dbiter_test.go | 102 +++ sync2/dbsync/fptree.go | 836 +++++++---------------- sync2/dbsync/fptree_test.go | 215 +++--- sync2/dbsync/inmemidstore.go | 71 ++ sync2/dbsync/inmemidstore_test.go | 78 +++ sync2/dbsync/refcountpool.go | 5 + sync2/dbsync/sqlidstore.go | 95 +++ sync2/dbsync/sqlidstore_test.go | 50 ++ sync2/internal/skiplist/skiplist.go | 185 +++++ sync2/internal/skiplist/skiplist_test.go | 167 +++++ 12 files changed, 1214 insertions(+), 702 deletions(-) create mode 100644 sync2/dbsync/inmemidstore.go create mode 100644 sync2/dbsync/inmemidstore_test.go create mode 100644 sync2/dbsync/sqlidstore.go create mode 100644 sync2/dbsync/sqlidstore_test.go create mode 100644 sync2/internal/skiplist/skiplist.go create mode 100644 sync2/internal/skiplist/skiplist_test.go diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 6a4d54fe80..7f8e34228e 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -2,6 +2,7 @@ package dbsync import ( "context" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) @@ -26,7 +27,7 @@ func NewDBItemStore( dbStore 
:= newDBBackedStore(db, query, keyLen, maxDepth) return &DBItemStore{ db: db, - ft: newFPTree(np, dbStore, maxDepth), + ft: newFPTree(np, dbStore, keyLen, maxDepth), query: query, keyLen: keyLen, maxDepth: maxDepth, @@ -80,20 +81,12 @@ func (d *DBItemStore) Copy() hashsync.ItemStore { // Has implements hashsync.ItemStore. func (d *DBItemStore) Has(k hashsync.Ordered) (bool, error) { - id := k.(KeyBytes) - if len(id) < d.keyLen { - panic("BUG: short key passed") - } - tailRefs := []tailRef{ - {ref: load64(id) >> (64 - d.maxDepth), limit: -1}, + it, err := d.ft.iter(k.(KeyBytes)) + if err == nil { + return k.Compare(it.Key()) == 0, nil } - found := false - if err := d.ft.iterateIDs(tailRefs, func(_ tailRef, cur KeyBytes) bool { - c := id.Compare(cur) - found = c == 0 - return c > 0 - }); err != nil { + if err != errEmptySet { return false, err } - return found, nil + return false, nil } diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index f3754d1f37..d7fcff6f85 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -2,6 +2,7 @@ package dbsync import ( "bytes" + "encoding/hex" "errors" "slices" @@ -13,6 +14,10 @@ type KeyBytes []byte var _ hashsync.Ordered = KeyBytes(nil) +func (k KeyBytes) String() string { + return hex.EncodeToString(k) +} + func (k KeyBytes) Clone() KeyBytes { return slices.Clone(k) } @@ -61,7 +66,7 @@ type dbRangeIterator struct { singleChunk bool } -var _ hashsync.Iterator = &dbRangeIterator{} +var _ iterator = &dbRangeIterator{} // makeDBIterator creates a dbRangeIterator and initializes it from the database. // If query returns no rows even after starting from zero ID, errEmptySet error is returned. 
@@ -70,7 +75,7 @@ func newDBRangeIterator( query string, from KeyBytes, maxChunkSize int, -) (hashsync.Iterator, error) { +) (iterator, error) { if from == nil { panic("BUG: makeDBIterator: nil from") } @@ -168,9 +173,7 @@ func (it *dbRangeIterator) load() error { func (it *dbRangeIterator) Key() hashsync.Ordered { if it.pos < len(it.chunk) { - key := make(KeyBytes, it.keyLen) - copy(key, it.chunk[it.pos]) - return key + return slices.Clone(it.chunk[it.pos]) } return nil } @@ -185,3 +188,81 @@ func (it *dbRangeIterator) Next() error { } return it.load() } + +func (it *dbRangeIterator) clone() iterator { + cloned := *it + cloned.from = slices.Clone(it.from) + cloned.chunk = make([]KeyBytes, len(it.chunk)) + for i, k := range it.chunk { + cloned.chunk[i] = slices.Clone(k) + } + return &cloned +} + +type combinedIterator struct { + iters []iterator + wrapped []iterator + ahead iterator + aheadIdx int +} + +// combineIterators combines multiple iterators into one, returning the smallest current +// key among all iterators at each step. 
+func combineIterators(iters ...iterator) iterator { + return &combinedIterator{iters: iters} +} + +func (c *combinedIterator) aheadIterator() iterator { + if c.ahead == nil { + if len(c.iters) == 0 { + if len(c.wrapped) == 0 { + return nil + } + c.iters = c.wrapped + c.wrapped = nil + } + c.ahead = c.iters[0] + c.aheadIdx = 0 + for i := 1; i < len(c.iters); i++ { + if c.iters[i].Key() != nil { + if c.ahead.Key() == nil || c.iters[i].Key().Compare(c.ahead.Key()) < 0 { + c.ahead = c.iters[i] + c.aheadIdx = i + } + } + } + } + return c.ahead +} + +func (c *combinedIterator) Key() hashsync.Ordered { + // return c.aheadIterator().Key() + it := c.aheadIterator() + return it.Key() +} + +func (c *combinedIterator) Next() error { + it := c.aheadIterator() + oldKey := it.Key() + if err := it.Next(); err != nil { + return err + } + c.ahead = nil + if oldKey.Compare(it.Key()) >= 0 { + // the iterator has wrapped around, move it to the wrapped list + // which will be used after all the iterators have wrapped around + c.wrapped = append(c.wrapped, it) + c.iters = append(c.iters[:c.aheadIdx], c.iters[c.aheadIdx+1:]...) 
+ } + return nil +} + +func (c *combinedIterator) clone() iterator { + cloned := &combinedIterator{ + iters: make([]iterator, len(c.iters)), + } + for i, it := range c.iters { + cloned.iters[i] = it.clone() + } + return cloned +} diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index aa71dfc478..e1dd3beeff 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -2,11 +2,13 @@ package dbsync import ( "encoding/hex" + "errors" "fmt" "slices" "testing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" "github.com/stretchr/testify/require" ) @@ -294,23 +296,123 @@ func TestDBRangeIterator(t *testing.T) { require.NoError(t, err) // when there are no items, errEmptySet is returned require.NotEmpty(t, tc.items) + clonedIt := it.clone() var collected []KeyBytes for i := 0; i < len(tc.items); i++ { k := it.Key() require.NotNil(t, k) collected = append(collected, k.(KeyBytes)) + require.Equal(t, k, clonedIt.Key()) require.NoError(t, it.Next()) + // calling Next on the original iterator + // shouldn't affect the cloned one + require.Equal(t, k, clonedIt.Key()) + require.NoError(t, clonedIt.Next()) } expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) require.Equal(t, expected, collected, "count=%d from=%s maxChunkSize=%d", len(tc.items), hex.EncodeToString(tc.from), maxChunkSize) + clonedIt = it.clone() for range 2 { for i := 0; i < len(tc.items); i++ { k := it.Key() require.Equal(t, collected[i], k.(KeyBytes)) + require.Equal(t, k, clonedIt.Key()) require.NoError(t, it.Next()) + require.Equal(t, k, clonedIt.Key()) + require.NoError(t, clonedIt.Next()) } } } } } + +type fakeIterator struct { + items, allItems []KeyBytes +} + +var _ hashsync.Iterator = &fakeIterator{} + +func (it *fakeIterator) Key() hashsync.Ordered { + if len(it.allItems) == 0 { + panic("no items") + } + if len(it.items) == 0 { + it.items = it.allItems + } + return KeyBytes(it.items[0]) +} + +func (it 
*fakeIterator) Next() error { + if len(it.items) == 0 { + it.items = it.allItems + } + it.items = it.items[1:] + if len(it.items) != 0 && string(it.items[0]) == "error" { + return errors.New("iterator error") + } + return nil +} + +func (it *fakeIterator) clone() iterator { + cloned := &fakeIterator{ + allItems: make([]KeyBytes, len(it.allItems)), + } + for i, k := range it.allItems { + cloned.allItems[i] = slices.Clone(k) + } + cloned.items = cloned.allItems[len(it.allItems)-len(it.items):] + return cloned +} + +func TestCombineIterators(t *testing.T) { + it1 := &fakeIterator{ + allItems: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x0a, 0x05, 0x00, 0x00}, + }, + } + it2 := &fakeIterator{ + allItems: []KeyBytes{ + {0x00, 0x00, 0x00, 0x03}, + {0xff, 0xff, 0xff, 0xff}, + }, + } + + it := combineIterators(it1, it2) + clonedIt := it.clone() + for range 3 { + var collected []KeyBytes + for i := 0; i < 4; i++ { + k := it.Key() + collected = append(collected, k.(KeyBytes)) + require.Equal(t, k, clonedIt.Key()) + require.NoError(t, it.Next()) + require.Equal(t, k, clonedIt.Key()) + require.NoError(t, clonedIt.Next()) + } + require.Equal(t, []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x0a, 0x05, 0x00, 0x00}, + {0xff, 0xff, 0xff, 0xff}, + }, collected) + require.Equal(t, KeyBytes{0x00, 0x00, 0x00, 0x01}, it.Key()) + } + + it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} + it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 1}}} + + it = combineIterators(it1, it2) + require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.Error(t, it.Next()) + + it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}}} + it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} + + it = combineIterators(it1, it2) + require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.NoError(t, it.Next()) + require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) + require.Error(t, it.Next()) +} diff --git 
a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 9bd8b2934e..38e3bb98da 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -4,17 +4,17 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "fmt" "io" "math/bits" "os" "runtime" - "slices" "strconv" "strings" "sync" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) type trace struct { @@ -210,10 +210,6 @@ func (p prefix) highBit() bool { return p.bits()>>(p.len()-1) != 0 } -func (p prefix) lowBit() bool { - return p&(1<>(64-p.len()) == p.bits() +} + func load64(h KeyBytes) uint64 { return binary.BigEndian.Uint64(h[:8]) } -// func hashPrefix(h KeyBytes, nbits int) prefix { -// if nbits < 0 || nbits > maxPrefixLen { -// panic("BUG: bad prefix length") -// } -// if nbits == 0 { -// return 0 -// } -// v := load64(h) -// return prefix((v>>(64-nbits-prefixLenBits))&prefixBitMask + uint64(nbits)) -// } - func preFirst0(h KeyBytes) prefix { l := min(maxPrefixLen, bits.LeadingZeros64(^load64(h))) return mkprefix((1<> (64 - maxDepth) - yRef := load64(y) >> (64 - maxDepth) - switch { - case xRef == yRef: - // Same ref for x and y. - // In this case, this tail may only contain relevant entries - // if it's ref is the same as xRef and yRef, and it needs - // to be used for both tails. - if tail.ref == xRef { - r.setStartTail(tail) - r.setEndTail(tail) - } - case r.itype < 0: - // Normal interval. - // The tail can cover the start in case if it's at or below xRef, - // and the end in case if it's at or below yRef, but after xRef. - if tail.ref <= xRef { - r.setStartTail(tail) - } else if tail.ref <= yRef { - r.setEndTail(tail) - } - default: - // Inverse interval. - // The tail can cover the start in case if it's at or below xRef, - // but also after yRef. - // It can cover the end in case if it's at or below yRef. 
- if tail.ref <= yRef { - r.setEndTail(tail) - } else if tail.ref <= xRef { - r.setStartTail(tail) - } + x := load64(ac.x) + y := load64(ac.y) + v := p.bits() << (64 - p.len()) + maxV := v + (1 << (64 - p.len())) - 1 + if ac.itype < 0 { + // normal interval + // fmt.Fprintf(os.Stderr, "QQQQQ: (0) itype %d x %016x y %016x v %016x maxV %016x result %v\n", ac.itype, x, y, v, maxV, v >= x && maxV < y) + return v >= x && maxV < y } + // inverted interval + // fmt.Fprintf(os.Stderr, "QQQQQ: (1) itype %d x %016x y %016x v %016x maxV %016x result %v\n", ac.itype, x, y, v, maxV, v >= x || v < y) + return v >= x || maxV < y } -func (r *aggResult) takeAtMost(count int) int { +func (ac *aggContext) maybeIncludeNode(node node) bool { switch { - case r.limit < 0: - return -1 - case count <= r.limit: - r.limit -= count + case ac.limit < 0: + case uint32(ac.limit) < node.c: + return false default: - count = r.limit - r.limit = 0 + ac.limit -= int(node.c) } - return count + ac.fp.update(node.fp[:]) + ac.count += node.c + return true } -func (r *aggResult) update(node node) { - r.fp.update(node.fp[:]) - r.count += node.c - // // fmt.Fprintf(os.Stderr, "QQQQQ: r.count <= %d r.fp <= %s\n", r.count, r.fp) +type iterator interface { + hashsync.Iterator + clone() iterator } type idStore interface { clone() idStore registerHash(h KeyBytes) error - iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error + iter(from KeyBytes) (iterator, error) } type fpTree struct { @@ -390,11 +334,18 @@ type fpTree struct { np *nodePool rootMtx sync.Mutex root nodeIndex + keyLen int maxDepth int } -func newFPTree(np *nodePool, idStore idStore, maxDepth int) *fpTree { - ft := &fpTree{np: np, idStore: idStore, root: noIndex, maxDepth: maxDepth} +func newFPTree(np *nodePool, idStore idStore, keyLen, maxDepth int) *fpTree { + ft := &fpTree{ + np: np, + idStore: idStore, + root: noIndex, + keyLen: keyLen, + maxDepth: maxDepth, + } runtime.SetFinalizer(ft, (*fpTree).release) return ft } @@ 
-531,6 +482,7 @@ func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeInde ft.enter("followPrefix: from %d p %s highBit %v", from, p, p.highBit()) defer func() { ft.leave(idx, rp, found) }() + // QQQQQ: refactor into a loop switch { case from == noIndex: return noIndex, followed, false @@ -545,127 +497,115 @@ func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeInde } } -func (ft *fpTree) tailRefFromPrefix(idx nodeIndex, p prefix, limit int) tailRef { - // TODO: QQQQ: FIXME: this may happen with reverse intervals, - // but should we even be checking the prefixes in this case? - if p.len() != ft.maxDepth { - panic("BUG: tail from short prefix") - } - return tailRef{ref: p.bits(), limit: limit} -} - -func (ft *fpTree) tailRefFromFingerprint(idx nodeIndex, fp fingerprint, limit int) tailRef { - v := load64(fp[:]) - if ft.maxDepth >= 64 { - return tailRef{ref: v, limit: limit} - } - // // fmt.Fprintf(os.Stderr, "QQQQQ: AAAAA: v %016x maxDepth %d shift %d\n", v, ft.maxDepth, (64 - ft.maxDepth)) - return tailRef{ref: v >> (64 - ft.maxDepth), limit: limit} -} - -func (ft *fpTree) tailRefFromNodeAndPrefix(idx nodeIndex, n node, p prefix, limit int) tailRef { - if n.c == 1 { - return ft.tailRefFromFingerprint(idx, n.fp, limit) +// aggregateEdge aggregates an edge of the interval, which can be bounded by x, y, both x +// and y or none of x and y, have a common prefix and optionally bounded by a limit of N of +// aggregated items. +// It returns a boolean indicating whether the limit or the right edge (y) was reached and +// an error, if any. 
+func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont bool, err error) { + ft.log("aggregateEdge: x %s y %s p %s limit %d count %d", x.String(), y.String(), p, ac.limit, ac.count) + defer func() { + ft.log("aggregateEdge ==> limit %d count %d\n", ac.limit, ac.count) + }() + if ac.limit == 0 { + ft.log("aggregateEdge: limit is 0") + return false, nil + } + var startFrom KeyBytes + if x == nil { + startFrom = make(KeyBytes, ft.keyLen) + p.minID(startFrom) } else { - return ft.tailRefFromPrefix(idx, p, limit) + startFrom = x } -} - -func (ft *fpTree) descendToLeftmostLeaf(idx nodeIndex, p prefix) (nodeIndex, prefix) { - switch { - case idx == noIndex: - return noIndex, p - case ft.np.node(idx).leaf(): - return idx, p - default: - return ft.descendToLeftmostLeaf(ft.np.node(idx).left, p.left()) + ft.log("aggregateEdge: startFrom %s", startFrom.String()) + it, err := ft.iter(startFrom) + if err != nil { + if errors.Is(err, errEmptySet) { + ft.log("aggregateEdge: empty set") + return false, nil + } + ft.log("aggregateEdge: error: %v", err) + return false, err + } + + for range ft.np.node(ft.root).c { + id := it.Key().(KeyBytes) + ft.log("aggregateEdge: ID %s", id.String()) + if y != nil && id.Compare(y) >= 0 { + ft.log("aggregateEdge: ID is over Y: %s", id.String()) + return false, nil + } + if !p.match(id) { + ft.log("aggregateEdge: ID doesn't match the prefix: %s", id.String()) + // Got to the end of the tailRef without exhausting the limit + return true, nil + } + ac.fp.update(id) + ac.count++ + if ac.limit > 0 { + ac.limit-- + if ac.limit == 0 { + ft.log("aggregateEdge: limit exhausted") + return false, nil + } + } + if err := it.Next(); err != nil { + ft.log("aggregateEdge: next error: %v", err) + return false, err + } } -} -// func (ft *fpTree) descendToNextLeaf(idx nodeIndex, p, rem prefix) (nodeIndex, prefix) { -// switch { -// case idx == noIndex: -// panic("BUG: descendToNextLeaf: no node") -// case rem == 0: -// return noIndex, p -// 
case rem.highBit(): -// // Descending to the right branch by following p: -// // the next leaf, if there's any, is further down the right branch. -// newIdx, newP := ft.descendToNextLeaf(ft.np.node(idx).right, p.right(), rem.shift()) -// return newIdx, newP -// default: -// // Descending to the left branch by following p: -// // if the leaf is not found in the left branch, it's the leftmost leaf -// // on the right branch -// newIdx, newP := ft.descendToNextLeaf(ft.np.node(idx).left, p.left(), rem.shift()) -// if newIdx != noIndex { -// return newIdx, newP -// } -// return ft.descendToLeftmostLeaf(ft.np.node(idx).right, p.right()) -// } -// } - -// func (ft *fpTree) nextLeaf(p prefix) (nodeIndex, prefix) { -// if ft.root == noIndex { -// return noIndex, 0 -// } -// return ft.descendToNextLeaf(ft.root, 0, p) -// } + return true, nil +} -func (ft *fpTree) visitNode(idx nodeIndex, p prefix, r *aggResult) (node, bool) { +func (ft *fpTree) node(idx nodeIndex) (node, bool) { if idx == noIndex { return node{}, false } - ft.log("visitNode: idx %d p %s", idx, p) - r.lastVisited = idx - r.lastPrefix = p return ft.np.node(idx), true } -func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, r *aggResult) (tailRef *tailRef, cont bool) { - ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, r.limit, - r.fp.String(), r.count) +func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (cont bool, err error) { + ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, ac.limit, + ac.fp.String(), ac.count) defer func() { - ft.leave(r.fp, r.count) + ft.leave(ac.fp, ac.count) }() for { - node, ok := ft.visitNode(idx, p, r) + node, ok := ft.node(idx) switch { case !ok: ft.log("stop: no node") - return nil, true - case r.limit == 0: - // for r.limit == 0, it's important that we still visit the node + return true, nil + case ac.limit == 0: + // for ac.limit == 0, it's important that we still visit the 
node // so that we can get the item immediately following the included items ft.log("stop: limit exhausted") - return nil, false - case r.limit < 0: - // no limit - ft.log("no limit") - r.update(node) - return nil, true - case node.c <= uint32(r.limit): + return false, nil + case ac.maybeIncludeNode(node): // node is fully included ft.log("included fully") - r.update(node) - r.limit -= int(node.c) - return nil, true + return true, nil case node.leaf(): // reached the limit on this node, do not need to continue after // done with it - tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) - ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) - return &tail, false + cont, err := ft.aggregateEdge(nil, nil, p, ac) + if err != nil { + return false, err + } + if cont { + panic("BUG: expected limit not reached") + } + return false, nil default: pLeft := p.left() - left, haveLeft := ft.visitNode(node.left, pLeft, r) + left, haveLeft := ft.node(node.left) if haveLeft { - if int(left.c) <= r.limit { + if ac.maybeIncludeNode(left) { // left node is fully included, after which // we need to stop somewhere in the right subtree ft.log("include left in full") - r.update(left) - r.limit -= int(left.c) } else { // we must stop somewhere in the left subtree, // and the right subtree is irrelevant @@ -682,161 +622,135 @@ func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, r *aggResult) (tai } } -func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, r *aggResult) (cont bool) { - ft.enter("aggregateLeft: idx %d v %016x p %s limit %d", idx, v, p, r.limit) +func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContext) (cont bool, err error) { + ft.enter("aggregateLeft: idx %d v %016x p %s limit %d", idx, v, p, ac.limit) defer func() { - ft.leave(r.fp, r.count, r.startTail, r.endTail) + ft.leave(ac.fp, ac.count) }() - node, ok := ft.visitNode(idx, p, r) + node, ok := ft.node(idx) switch { case !ok: - // for r.limit 
== 0, it's important that we still visit the node + // for ac.limit == 0, it's important that we still visit the node // so that we can get the item immediately following the included items ft.log("stop: no node") - return true - case r.limit == 0: + return true, nil + case ac.limit == 0: ft.log("stop: limit exhausted") - return false - case p.len() == ft.maxDepth: + return false, nil + case ac.prefixWithinRange(p) && ac.maybeIncludeNode(node): + ft.log("including node in full: %s limit %d", p, ac.limit) + return ac.limit != 0, nil + case p.len() == ft.maxDepth || node.leaf(): if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } - tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) - r.setStartTail(&tail) - ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) - return true - case node.leaf(): // TBD: combine with prev - // For leaf 1-nodes, we can use the fingerprint to get tailRef - // by which the actual IDs will be selected - if node.c != 1 { - panic("BUG: leaf non-1 node below maxDepth") - } - tail := ft.tailRefFromFingerprint(idx, node.fp, r.takeAtMost(1)) - r.setStartTail(&tail) - ft.log("add prefix to the tails (1-leaf): %016x (fp %s) => limit %d", tail.ref, node.fp, r.limit) - return true + return ft.aggregateEdge(ac.x, nil, p, ac) case v&bit63 == 0: ft.log("incl right node %d + go left to node %d", node.right, node.left) - if !ft.aggregateLeft(node.left, v<<1, p.left(), r) { - return false + cont, err := ft.aggregateLeft(node.left, v<<1, p.left(), ac) + if !cont || err != nil { + return false, err } if node.right != noIndex { - tail, cont := ft.aggregateUpToLimit(node.right, p.right(), r) - if tail != nil { - r.setStartTail(tail) - } - return cont + return ft.aggregateUpToLimit(node.right, p.right(), ac) } - return true + return true, nil default: ft.log("go right to node %d", node.right) - return ft.aggregateLeft(node.right, v<<1, p.right(), r) + return ft.aggregateLeft(node.right, 
v<<1, p.right(), ac) } } -func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, r *aggResult) (cont bool) { - ft.enter("aggregateRight: idx %d v %016x p %s limit %d", idx, v, p, r.limit) +func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggContext) (cont bool, err error) { + ft.enter("aggregateRight: idx %d v %016x p %s limit %d", idx, v, p, ac.limit) defer func() { - ft.leave(r.fp, r.count, r.startTail, r.endTail) + ft.leave(ac.fp, ac.count) }() - node, ok := ft.visitNode(idx, p, r) + node, ok := ft.node(idx) switch { case !ok: - // for r.limit == 0, it's important that we still visit the node + // for ac.limit == 0, it's important that we still visit the node // so that we can get the item immediately following the included items ft.log("stop: no node") - return true - case r.limit == 0: + return true, nil + case ac.limit == 0: ft.log("stop: limit exhausted") - return false - case p.len() == ft.maxDepth: + return false, nil + case ac.prefixWithinRange(p) && ac.maybeIncludeNode(node): + ft.log("including node in full: %s limit %d", p, ac.limit) + return ac.limit != 0, nil + case p.len() == ft.maxDepth || node.leaf(): if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } - tail := ft.tailRefFromPrefix(idx, p, r.takeAtMost(int(node.c))) - r.setEndTail(&tail) - ft.log("add prefix to the tails: %016x => limit %d", tail.ref, r.limit) - return true - case node.leaf(): - // For leaf 1-nodes, we can use the fingerprint to get tailRef - // by which the actual IDs will be selected - if node.c != 1 { - panic("BUG: leaf non-1 node below maxDepth") - } - tail := ft.tailRefFromFingerprint(idx, node.fp, r.takeAtMost(1)) - r.setEndTail(&tail) - ft.log("add prefix to the tails (1-leaf): %016x (fp %s) => limit %d", tail.ref, node.fp, r.limit) - return true + return ft.aggregateEdge(nil, ac.y, p, ac) case v&bit63 == 0: ft.log("go left to node %d", node.left) - return ft.aggregateRight(node.left, v<<1, 
p.left(), r) + return ft.aggregateRight(node.left, v<<1, p.left(), ac) default: ft.log("incl left node %d + go right to node %d", node.left, node.right) if node.left != noIndex { - tail, cont := ft.aggregateUpToLimit(node.left, p.left(), r) - if tail != nil { - r.setEndTail(tail) - } - if !cont { - return false + cont, err := ft.aggregateUpToLimit(node.left, p.left(), ac) + if !cont || err != nil { + return false, err } } - return ft.aggregateRight(node.right, v<<1, p.right(), r) + return ft.aggregateRight(node.right, v<<1, p.right(), ac) } } -func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { +func (ft *fpTree) aggregateInterval(ac *aggContext) error { ft.rootMtx.Lock() defer ft.rootMtx.Unlock() - ft.enter("aggregateInterval: x %s y %s limit %d", hex.EncodeToString(x), hex.EncodeToString(y), limit) + ft.enter("aggregateInterval: x %s y %s limit %d", ac.x.String(), ac.y.String(), ac.limit) defer func() { - ft.leave(r) + ft.leave(ac) }() - r = aggResult{limit: limit, lastVisited: noIndex} - r.itype = bytes.Compare(x, y) + if ft.root == noIndex { + return nil + } + ac.total = ft.np.node(ft.root).c + ac.itype = bytes.Compare(ac.x, ac.y) switch { - case r.itype == 0: + case ac.itype == 0: // the whole set if ft.root != noIndex { ft.log("whole set") - tail, _ := ft.aggregateUpToLimit(ft.root, 0, &r) - r.setTails(tail, x, y, ft.maxDepth) + _, err := ft.aggregateUpToLimit(ft.root, 0, ac) + return err } else { ft.log("empty set (no root)") } - case r.itype < 0: + + case ac.itype < 0: // "proper" interval: [x; lca); (lca; y) - p := commonPrefix(x, y) - lcaIdx, followedPrefix, found := ft.followPrefix(ft.root, p, 0) - var lcaNode node + p := commonPrefix(ac.x, ac.y) + lcaIdx, lcaPrefix, fullPrefixFound := ft.followPrefix(ft.root, p, 0) + var lca node if lcaIdx != noIndex { - lcaNode = ft.np.node(lcaIdx) + // QQQQQ: TBD: perhaps just return if lcaIdx == noIndex + lca = ft.np.node(lcaIdx) } - ft.log("commonPrefix %s lca %d found %v", p, lcaIdx, 
found) + ft.log("commonPrefix %s lca %d found %v", p, lcaIdx, fullPrefixFound) switch { - case found && !lcaNode.leaf(): - if followedPrefix != p { + case fullPrefixFound && !lca.leaf(): + if lcaPrefix != p { panic("BUG: bad followedPrefix") } - ft.visitNode(lcaIdx, followedPrefix, &r) - ft.aggregateLeft(lcaNode.left, load64(x)<<(p.len()+1), p.left(), &r) - ft.aggregateRight(lcaNode.right, load64(y)<<(p.len()+1), p.right(), &r) - case lcaIdx != noIndex: - ft.log("commonPrefix %s NOT found but have lca %d", p, lcaIdx) - // Didn't reach LCA in the tree b/c ended up - // at a leaf, just use the prefix to go - // through the IDs - if lcaNode.leaf() { - ft.visitNode(lcaIdx, followedPrefix, &r) - tail := ft.tailRefFromNodeAndPrefix( - lcaIdx, lcaNode, followedPrefix, r.takeAtMost(limit)) - r.setTails(&tail, x, y, ft.maxDepth) - } + ft.aggregateLeft(lca.left, load64(ac.x)<<(p.len()+1), p.left(), ac) + ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac) + case lcaIdx == noIndex || !lca.leaf(): + ft.log("commonPrefix %s NOT found b/c no items have it", p) + default: + ft.log("commonPrefix %s -- lca %d", p, lcaIdx) + _, err := ft.aggregateEdge(ac.x, ac.y, lcaPrefix, ac) + return err } + default: // inverse interval: [min; y); [x; max] // first, we handle [x; max] part - pf0 := preFirst0(x) + pf0 := preFirst0(ac.x) idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, 0) var pf0Node node if idx0 != noIndex { @@ -848,18 +762,24 @@ func (ft *fpTree) aggregateInterval(x, y KeyBytes, limit int) (r aggResult) { if followedPrefix != pf0 { panic("BUG: bad followedPrefix") } - ft.aggregateLeft(idx0, load64(x)<= x, it'll be RangeInfo.Start, - // then we clone the iterator (it should be cloneable!) 
and use it to include the IDs - // in the fingerprint - // For 2nd tailRef we call Next() on the iterator till Key() is >= y, including IDs - // in the fingerprint, then return it as RangeInfo.End (outside the range) - // DOWNSIDE: too much needs to be fetched in case of bigger chunks - // Possible optimization [DONE]: specify max chunk size for db iterator, - // start from 1, increase it 2x on each iteration but not over max chunk size - // QQQQQ: TBD: need to restore combinedIterator - if err := ft.idStore.iterateIDs(tailRefs, func(tailRef tailRef, id KeyBytes) bool { - if idWithinInterval(id, x, y, r.itype) { - r.fp.update(id) - r.count++ - ft.log("tailRef %v: id %s within range => fp %s count %d", - tailRef, - hex.EncodeToString(id), - r.fp.String(), r.count) - wasWithinRange = true - } else { - // if we were within the range but now we're out of it, - // this means we're at or beyond y and can stop - // return !wasWithinRange || noStop - // QQQQQ: rmme - if wasWithinRange { - ft.log("tailRef %v: id %s outside range after id(s) within range => terminating", - tailRef, - hex.EncodeToString(id)) - // TBD: QQQQQ: terminate only for this tailRef - return noStop - } else { - ft.log("tailRef %v: id %s outside range => continuing", - tailRef, - hex.EncodeToString(id)) - return true - } - } - return true - }); err != nil { + ac := aggContext{x: x, y: y, limit: limit} + if err := ft.aggregateInterval(&ac); err != nil { return fpResult{}, err } - return fpResult{fp: r.fp, count: r.count, itype: r.itype}, nil + return fpResult{fp: ac.fp, count: ac.count, itype: ac.itype}, nil } func (ft *fpTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) { @@ -975,237 +846,6 @@ func (ft *fpTree) dump(w io.Writer) { } } -type memIDStore struct { - mtx sync.Mutex - ids map[uint64][]KeyBytes - maxDepth int -} - -var _ idStore = &memIDStore{} - -func newMemIDStore(maxDepth int) *memIDStore { - return &memIDStore{maxDepth: maxDepth} -} - -func (m *memIDStore) clone() idStore { - 
m.mtx.Lock() - defer m.mtx.Unlock() - s := newMemIDStore(m.maxDepth) - if m.ids != nil { - s.ids = make(map[uint64][]KeyBytes, len(m.ids)) - for k, v := range m.ids { - s.ids[k] = slices.Clone(v) - } - } - return s -} - -func (m *memIDStore) registerHash(h KeyBytes) error { - m.mtx.Lock() - defer m.mtx.Unlock() - if m.ids == nil { - m.ids = make(map[uint64][]KeyBytes, 1<> (64 - m.maxDepth) - s := m.ids[idx] - n := slices.IndexFunc(s, func(cur KeyBytes) bool { - return bytes.Compare(cur, h) > 0 - }) - if n < 0 { - m.ids[idx] = append(s, h) - } else { - m.ids[idx] = slices.Insert(s, n, h) - } - return nil -} - -func (m *memIDStore) iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error { - m.mtx.Lock() - defer m.mtx.Unlock() - if m.ids == nil { - return nil - } - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: maxDepth %d tailRefs %v\n", m.maxDepth, tailRefs) - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: ids %#v\n", m.ids) - for _, t := range tailRefs { - count := t.limit - if count == 0 { - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v: count == 0\n", t) - continue - } - for _, id := range m.ids[t.ref] { - if count == 0 { - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: count == 0\n", t, hex.EncodeToString(id)) - break - } - if count > 0 { - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: dec count\n", t, hex.EncodeToString(id)) - count-- - } - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: call\n", t, hex.EncodeToString(id)) - if !toCall(t, id) { - // fmt.Fprintf(os.Stderr, "QQQQQ: memIDStore: iterateIDs: t %v id %s: stop\n", t, hex.EncodeToString(id)) - return nil - } - } - } - return nil -} - -type sqlIDStore struct { - db sql.Database - query string - keyLen int - maxDepth int -} - -var _ idStore = &sqlIDStore{} - -func newSQLIDStore(db sql.Database, query string, keyLen, maxDepth int) *sqlIDStore { - return &sqlIDStore{db: db, query: 
query, keyLen: keyLen, maxDepth: maxDepth} -} - -func (s *sqlIDStore) clone() idStore { - return newSQLIDStore(s.db, s.query, s.keyLen, s.maxDepth) -} - -func (s *sqlIDStore) registerHash(h KeyBytes) error { - // should be registered by the handler code - return nil -} - -func (s *sqlIDStore) iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error { - cont := true - for _, t := range tailRefs { - if t.limit == 0 { - continue - } - p := mkprefix(t.ref, s.maxDepth) - minID := make([]byte, s.keyLen) - maxID := make([]byte, s.keyLen) - p.minID(minID[:]) - p.maxID(maxID[:]) - // start := time.Now() - query := s.query - if t.limit > 0 { - query += " LIMIT " + strconv.Itoa(t.limit) - } - if _, err := s.db.Exec( - query, - func(stmt *sql.Statement) { - stmt.BindBytes(1, minID) - stmt.BindBytes(2, maxID) - }, - func(stmt *sql.Statement) bool { - id := make(KeyBytes, s.keyLen) - stmt.ColumnBytes(0, id) - cont = toCall(t, id) - return cont - }, - ); err != nil { - return err - } - // fmt.Fprintf(os.Stderr, "QQQQQ: %v: sel atxs between %s and %s\n", time.Now().Sub(start), minID.String(), maxID.String()) - if !cont { - break - } - } - return nil -} - -type dbBackedStore struct { - *sqlIDStore - *memIDStore - maxDepth int -} - -var _ idStore = &dbBackedStore{} - -func newDBBackedStore(db sql.Database, query string, keyLen, maxDepth int) *dbBackedStore { - return &dbBackedStore{ - sqlIDStore: newSQLIDStore(db, query, keyLen, maxDepth), - memIDStore: newMemIDStore(maxDepth), - maxDepth: maxDepth, - } -} - -func (s *dbBackedStore) clone() idStore { - return &dbBackedStore{ - sqlIDStore: s.sqlIDStore.clone().(*sqlIDStore), - memIDStore: s.memIDStore.clone().(*memIDStore), - maxDepth: s.maxDepth, - } -} - -func (s *dbBackedStore) registerHash(h KeyBytes) error { - return s.memIDStore.registerHash(h) -} - -func (s *dbBackedStore) iterateIDs(tailRefs []tailRef, toCall func(tailRef, KeyBytes) bool) error { - type memItem struct { - tailRef tailRef - id KeyBytes - } - 
var memItems []memItem - s.memIDStore.iterateIDs(tailRefs, func(tailRef tailRef, id KeyBytes) bool { - memItems = append(memItems, memItem{tailRef: tailRef, id: id}) - return true - }) - cont := true - limits := make(map[uint64]int, len(tailRefs)) - for _, t := range tailRefs { - if t.limit >= 0 { - limits[t.ref] += t.limit - } - } - if err := s.sqlIDStore.iterateIDs(tailRefs, func(tailRef tailRef, id KeyBytes) bool { - ref := load64(id) >> (64 - s.maxDepth) - limit, haveLimit := limits[ref] - for len(memItems) > 0 && bytes.Compare(memItems[0].id, id) < 0 { - if haveLimit && limit == 0 { - return false - } - cont = toCall(memItems[0].tailRef, memItems[0].id) - if !cont { - return false - } - limits[ref] = limit - 1 - memItems = memItems[1:] - } - if haveLimit && limit == 0 { - return false - } - cont = toCall(tailRef, id) - limits[ref] = limit - 1 - return cont - }); err != nil { - return err - } - if cont { - for _, mi := range memItems { - ref := load64(mi.id) >> (64 - s.maxDepth) - limit, haveLimit := limits[ref] - if haveLimit && limit == 0 { - break - } - if !toCall(mi.tailRef, mi.id) { - break - } - limits[ref] = limit - 1 - } - } - return nil -} - -func idWithinInterval(id, x, y KeyBytes, itype int) bool { - switch itype { - case 0: - return true - case -1: - return bytes.Compare(id, x) >= 0 && bytes.Compare(id, y) < 0 - default: - return bytes.Compare(id, y) < 0 || bytes.Compare(id, x) >= 0 - } -} - // TBD: optimize, get rid of binary.BigEndian.* +// TBD: QQQQQ: detect unbalancedness when a ref gets too many items +// TBD: QQQQQ: ItemStore.Close(): close db conns, also free fpTree instead of using finalizer! 
diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index e8b1fe5f06..6ed2648362 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -143,10 +143,11 @@ func TestPrefix(t *testing.T) { tc.p.minID(minID[:]) require.Equal(t, expMinID, minID) - expMaxID := types.HexToHash32(tc.maxID) - var maxID types.Hash32 - tc.p.maxID(maxID[:]) - require.Equal(t, expMaxID, maxID) + // QQQQQ: TBD: rm (probably with maxid fields?) + // expMaxID := types.HexToHash32(tc.maxID) + // var maxID types.Hash32 + // tc.p.maxID(maxID[:]) + // require.Equal(t, expMaxID, maxID) }) } } @@ -195,7 +196,7 @@ type fakeIDDBStore struct { var _ idStore = &fakeIDDBStore{} -const fakeIDQuery = "select id from foo where id between ? and ? order by id" +const fakeIDQuery = "select id from foo where id >= ? order by id limit ?" func newFakeATXIDStore(db sql.Database, maxDepth int) *fakeIDDBStore { return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, fakeIDQuery, 32, maxDepth)} @@ -249,6 +250,14 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { count: 5, itype: 0, }, + { + xIdx: 0, + yIdx: 0, + limit: 0, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: 0, + }, { xIdx: 0, yIdx: 0, @@ -289,6 +298,14 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { count: 3, itype: -1, }, + { + xIdx: 0, + yIdx: 4, + limit: 0, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: -1, + }, { xIdx: 1, yIdx: 4, @@ -313,6 +330,14 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { count: 3, itype: 1, }, + { + xIdx: 2, + yIdx: 0, + limit: 0, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: 1, + }, { xIdx: 3, yIdx: 1, @@ -491,7 +516,8 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { t.Run(tc.name, func(t *testing.T) { var np nodePool idStore := makeIDStore(tc.maxDepth) - ft := newFPTree(&np, idStore, tc.maxDepth) + ft := newFPTree(&np, idStore, 32, tc.maxDepth) + // 
ft.traceEnabled = true var hs []types.Hash32 for _, hex := range tc.ids { t.Logf("add: %s", hex) @@ -521,7 +547,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { fp: rtc.fp, count: rtc.count, itype: rtc.itype, - }, fpr) + }, fpr, "range: x %s y %s", x.String(), y.String()) } ft.release() @@ -532,7 +558,9 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { func TestFPTree(t *testing.T) { t.Run("in-memory id store", func(t *testing.T) { - testFPTree(t, func(maxDepth int) idStore { return newMemIDStore(maxDepth) }) + testFPTree(t, func(maxDepth int) idStore { + return newInMemIDStore(32, maxDepth) + }) }) t.Run("fake ATX store", func(t *testing.T) { db := populateDB(t, 32, nil) @@ -546,7 +574,7 @@ func TestFPTree(t *testing.T) { func TestFPTreeClone(t *testing.T) { var np nodePool - ft1 := newFPTree(&np, newMemIDStore(24), 24) + ft1 := newFPTree(&np, newInMemIDStore(32, 24), 32, 24) hashes := []types.Hash32{ types.HexToHash32("1111111111111111111111111111111111111111111111111111111111111111"), types.HexToHash32("3333333333333333333333333333333333333333333333333333333333333333"), @@ -671,7 +699,7 @@ func repeatTestFPTreeManyItems( } } -func dumbFP(hs hashList, x, y types.Hash32) fpResult { +func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResult { var fpr fpResult fpr.itype = x.Compare(y) switch fpr.itype { @@ -679,33 +707,40 @@ func dumbFP(hs hashList, x, y types.Hash32) fpResult { pX := hs.findGTE(x) pY := hs.findGTE(y) // t.Logf("x=%s y=%s pX=%d y=%d", x.String(), y.String(), pX, pY) - for p := pX; p < pY; p++ { + for p := pX; p < pY && limit != 0; p++ { // t.Logf("XOR %s", hs[p].String()) fpr.fp.update(hs[p][:]) + limit-- + fpr.count++ } - fpr.count = uint32(pY - pX) case 1: pX := hs.findGTE(x) pY := hs.findGTE(y) - for p := 0; p < pY; p++ { + for p := pX; p < len(hs) && limit != 0; p++ { fpr.fp.update(hs[p][:]) + limit-- + fpr.count++ } - for p := pX; p < len(hs); p++ { + for p := 0; p < pY && limit != 0; p++ { fpr.fp.update(hs[p][:]) + 
limit-- + fpr.count++ } - fpr.count = uint32(pY + len(hs) - pX) default: - for _, h := range hs { + for n, h := range hs { + if limit >= 0 && n >= limit { + break + } fpr.fp.update(h[:]) + fpr.count++ } - fpr.count = uint32(len(hs)) } return fpr } func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, maxDepth int) { var np nodePool - ft := newFPTree(&np, idStore, maxDepth) + ft := newFPTree(&np, idStore, 32, maxDepth) // ft.traceEnabled = true hs := make(hashList, numItems) var fp fingerprint @@ -734,7 +769,7 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, x = hs[rand.Intn(numItems)] y = hs[rand.Intn(numItems)] } - expFPR := dumbFP(hs, x, y) + expFPR := dumbFP(hs, x, y, -1) fpr, err := ft.fingerprintInterval(x[:], y[:], -1) require.NoError(t, err) @@ -751,6 +786,28 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, // QQQQQ: /rm require.Equal(t, expFPR, fpr) + + limit := 0 + if fpr.count != 0 { + limit = rand.Intn(int(fpr.count)) + } + expFPR = dumbFP(hs, x, y, limit) + fpr, err = ft.fingerprintInterval(x[:], y[:], limit) + require.NoError(t, err) + + // QQQQQ: rm + if !reflect.DeepEqual(fpr, expFPR) { + t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) + for _, h := range hs { + t.Logf("QQQQQ: hash: %s", h.String()) + } + var sb strings.Builder + ft.dump(&sb) + t.Logf("QQQQQ: tree:\n%s", sb.String()) + } + // QQQQQ: /rm + + require.Equal(t, expFPR, fpr, "x=%s y=%s limit=%d", x.String(), y.String(), limit) } } @@ -758,18 +815,18 @@ func TestFPTreeManyItems(t *testing.T) { const ( // numItems = 1 << 16 // maxDepth = 24 - numItems = 1 << 5 + numItems = 1 << 2 // 1 << 5 maxDepth = 4 ) t.Run("bounds from the set", func(t *testing.T) { repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { - return newMemIDStore(maxDepth) + return newInMemIDStore(32, maxDepth) }, false, numItems, maxDepth) }) t.Run("random bounds", func(t *testing.T) { repeatTestFPTreeManyItems(t, 
func(maxDepth int) idStore { - return newMemIDStore(maxDepth) + return newInMemIDStore(32, maxDepth) }, true, numItems, maxDepth) }) t.Run("SQL, bounds from the set", func(t *testing.T) { @@ -788,6 +845,8 @@ func TestFPTreeManyItems(t *testing.T) { return newFakeATXIDStore(db, maxDepth) }, true, numItems, maxDepth) }) + // TBD: test limits with both random and non-random bounds + // TBD: test start/end iterators } const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" @@ -820,11 +879,12 @@ const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7 // } // } -func testATXFP(t *testing.T, maxDepth int) { +func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { // t.Skip("slow tmp test") // counts := make(map[uint64]uint64) // prefLens := make(map[int]int) - db, err := statesql.Open("file:" + dbFile) + // QQQQQ: TBD: reenable schema drift check + db, err := statesql.Open("file:"+dbFile, sql.WithIgnoreSchemaDrift()) require.NoError(t, err) defer db.Close() // _, err = db.Exec("PRAGMA cache_size = -2000000", nil, nil) @@ -833,30 +893,32 @@ func testATXFP(t *testing.T, maxDepth int) { // first := true // where epoch=23 var np nodePool - t.Logf("loading IDs") - var hs []types.Hash32 - _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { - var id types.Hash32 - stmt.ColumnBytes(0, id[:]) - hs = append(hs, id) - // v := load64(id[:]) - // counts[v>>40]++ - // if first { - // first = false - // } else { - // prefLens[bits.LeadingZeros64(prev^v)]++ - // } - // prev = v - return true - }) - require.NoError(t, err) + if *hs == nil { + t.Logf("loading IDs") + _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { + var id types.Hash32 + stmt.ColumnBytes(0, id[:]) + *hs = append(*hs, id) + // v := load64(id[:]) + // counts[v>>40]++ + // if first { + // first = false + // } else { + // prefLens[bits.LeadingZeros64(prev^v)]++ + // } + // 
prev = v + return true + }) + require.NoError(t, err) + } + // TODO: use testing.B and b.ReportAllocs() runtime.GC() var stats1 runtime.MemStats runtime.ReadMemStats(&stats1) - store := newSQLIDStore(db, "select id from atxs where id between ? and ? order by id", 32, maxDepth) - ft := newFPTree(&np, store, maxDepth) - for _, id := range hs { + store := newSQLIDStore(db, "select id from atxs where id >= ? order by id limit ?", 32, maxDepth) + ft := newFPTree(&np, store, 32, maxDepth) + for _, id := range *hs { ft.addHash(id[:]) } @@ -902,61 +964,44 @@ func testATXFP(t *testing.T, maxDepth int) { x := types.RandomHash() y := types.RandomHash() t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) - expFPResult := dumbFP(hs, x, y) + expFPResult := dumbFP(*hs, x, y, -1) //expFPResult := dumbAggATXs(t, db, x, y) fpr, err := ft.fingerprintInterval(x[:], y[:], -1) require.NoError(t, err) require.Equal(t, expFPResult, fpr, "x=%s y=%s", x.String(), y.String()) + + limit := 0 + if fpr.count != 0 { + limit = rand.Intn(int(fpr.count)) + } + t.Logf("QQQQQ: x=%s y=%s limit=%d", x.String(), y.String(), limit) + expFPResult = dumbFP(*hs, x, y, limit) + fpr, err = ft.fingerprintInterval(x[:], y[:], limit) + require.NoError(t, err) + require.Equal(t, expFPResult, fpr, "x=%s y=%s limit=%d", x.String(), y.String(), limit) } + + // x := types.HexToHash32("ab27e01be51af3775fa20299767aef712128a021ffdca7617b31c9ca811376d2") + // y := types.HexToHash32("20c64bc7ea2114babe08380be6cc379aebc715f3820aca52f44fe10748d792a3") + // t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) + // expFPResult := dumbFP(hs, x, y) + // //expFPResult := dumbAggATXs(t, db, x, y) + // ft.traceEnabled = true + // fpr, err := ft.fingerprintInterval(x[:], y[:], -1) + // require.NoError(t, err) + // require.Equal(t, expFPResult, fpr, "x=%s y=%s", x.String(), y.String()) } func TestATXFP(t *testing.T) { - t.Skip("slow test") + // t.Skip("slow test") + var hs []types.Hash32 for maxDepth := 15; maxDepth <= 23; maxDepth++ { 
for i := 0; i < 3; i++ { - testATXFP(t, maxDepth) + testATXFP(t, maxDepth, &hs) } } } -func TestDBBackedStore(t *testing.T) { - // create an in-memory-database, put some ids into it, - // create dbBackedStore, read the ids from the database and check them, - // then add some ids to the dbBackedStore but not to the database, - // and re-check the dbBackedStore contents using iterateIDs method - // use plain sql.InMemory and foo table like in TestDBRangeIterator - initialIDs := []KeyBytes{ - {0, 0, 0, 1, 0, 0, 0, 0}, - {0, 0, 0, 3, 0, 0, 0, 0}, - {0, 0, 0, 5, 0, 0, 0, 0}, - {0, 0, 0, 7, 0, 0, 0, 0}, - } - db := populateDB(t, 8, initialIDs) - store := newDBBackedStore(db, fakeIDQuery, 8, 24) - var actualIDs []KeyBytes - require.NoError(t, store.iterateIDs([]tailRef{{ref: 0, limit: -1}}, func(_ tailRef, id KeyBytes) bool { - actualIDs = append(actualIDs, id) - return true - })) - require.Equal(t, initialIDs, actualIDs) - - require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) - require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) - actualIDs = nil - require.NoError(t, store.iterateIDs([]tailRef{{ref: 0, limit: -1}}, func(_ tailRef, id KeyBytes) bool { - actualIDs = append(actualIDs, id) - return true - })) - require.Equal(t, []KeyBytes{ - {0, 0, 0, 1, 0, 0, 0, 0}, - {0, 0, 0, 2, 0, 0, 0, 0}, - {0, 0, 0, 3, 0, 0, 0, 0}, - {0, 0, 0, 5, 0, 0, 0, 0}, - {0, 0, 0, 7, 0, 0, 0, 0}, - {0, 0, 0, 9, 0, 0, 0, 0}, - }, actualIDs) -} - // benchmarks // maxDepth 18: 94.739µs per range, 10555.290991 ranges/s, heap diff 16621568 diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go new file mode 100644 index 0000000000..f970294d08 --- /dev/null +++ b/sync2/dbsync/inmemidstore.go @@ -0,0 +1,71 @@ +package dbsync + +import ( + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/internal/skiplist" +) + +type inMemIDStore struct { + sl *skiplist.SkipList + keyLen int + maxDepth int + len 
int +} + +var _ idStore = &inMemIDStore{} + +func newInMemIDStore(keyLen, maxDepth int) *inMemIDStore { + return &inMemIDStore{ + sl: skiplist.New(keyLen), + keyLen: keyLen, + maxDepth: maxDepth, + } +} + +func (s *inMemIDStore) clone() idStore { + newStore := newInMemIDStore(s.keyLen, s.maxDepth) + for node := s.sl.First(); node != nil; node = node.Next() { + newStore.sl.Add(node.Key()) + } + return newStore +} + +func (s *inMemIDStore) registerHash(h KeyBytes) error { + s.sl.Add(h) + s.len++ + return nil +} + +func (s *inMemIDStore) iter(from KeyBytes) (iterator, error) { + node := s.sl.FindGTENode(from) + if node == nil { + return nil, errEmptySet + } + return &inMemIDStoreIterator{sl: s.sl, node: node}, nil +} + +type inMemIDStoreIterator struct { + sl *skiplist.SkipList + node *skiplist.Node +} + +var _ iterator = &inMemIDStoreIterator{} + +func (it *inMemIDStoreIterator) Key() hashsync.Ordered { + return KeyBytes(it.node.Key()) +} + +func (it *inMemIDStoreIterator) Next() error { + if it.node = it.node.Next(); it.node == nil { + it.node = it.sl.First() + if it.node == nil { + panic("BUG: iterator returned for an empty skiplist") + } + } + return nil +} + +func (it *inMemIDStoreIterator) clone() iterator { + cloned := *it + return &cloned +} diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go new file mode 100644 index 0000000000..b6943d08e8 --- /dev/null +++ b/sync2/dbsync/inmemidstore_test.go @@ -0,0 +1,78 @@ +package dbsync + +import ( + "encoding/hex" + "testing" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/common/util" + "github.com/stretchr/testify/require" +) + +func TestInMemIDStore(t *testing.T) { + s := newInMemIDStore(32, 24) + + _, err := s.iter(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + require.ErrorIs(t, err, errEmptySet) + + for _, h := range []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + 
"1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + } { + s.registerHash(util.FromHex(h)) + } + + for range 2 { + it, err := s.iter( + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + var items []string + for range 7 { + items = append(items, hex.EncodeToString(it.Key().(KeyBytes))) + require.NoError(t, it.Next()) + } + require.Equal(t, []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + }, items) + require.Equal(t, + "0000000000000000000000000000000000000000000000000000000000000000", + hex.EncodeToString(it.Key().(KeyBytes))) + + s1 := s.clone() + h := types.BytesToHash( + util.FromHex("2000000000000000000000000000000000000000000000000000000000000000")) + s1.registerHash(h[:]) + items = nil + it, err = s1.iter( + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + for range 8 { + items = append(items, hex.EncodeToString(it.Key().(KeyBytes))) + require.NoError(t, it.Next()) + } + require.Equal(t, []string{ + "0000000000000000000000000000000000000000000000000000000000000000", + "1234561111111111111111111111111111111111111111111111111111111111", + 
"123456789abcdef0000000000000000000000000000000000000000000000000", + "2000000000000000000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", + }, items) + require.Equal(t, + "0000000000000000000000000000000000000000000000000000000000000000", + hex.EncodeToString(it.Key().(KeyBytes))) + } +} diff --git a/sync2/dbsync/refcountpool.go b/sync2/dbsync/refcountpool.go index 8a112a521e..165f911a25 100644 --- a/sync2/dbsync/refcountpool.go +++ b/sync2/dbsync/refcountpool.go @@ -13,6 +13,11 @@ type poolEntry[T any, I ~uint32] struct { content T } +// rcPool is a reference-counted pool of items. It is safe for concurrent use. +// The zero value is a valid, empty rcPool. +// Unlike sync.Pool, rcPool does not shrink, but uint32 indices can be used +// to reference items instead of larger 64-bit pointers, and the items +// can be shared between type rcPool[T any, I ~uint32] struct { mtx sync.Mutex entries []poolEntry[T, I] diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go new file mode 100644 index 0000000000..320bec9918 --- /dev/null +++ b/sync2/dbsync/sqlidstore.go @@ -0,0 +1,95 @@ +package dbsync + +import ( + "bytes" + "errors" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +const sqlMaxChunkSize = 1024 + +type sqlIDStore struct { + db sql.Database + query string + keyLen int + maxDepth int +} + +var _ idStore = &sqlIDStore{} + +func newSQLIDStore(db sql.Database, query string, keyLen, maxDepth int) *sqlIDStore { + return &sqlIDStore{db: db, query: query, keyLen: keyLen, maxDepth: maxDepth} +} + +func (s *sqlIDStore) clone() idStore { + return newSQLIDStore(s.db, s.query, s.keyLen, s.maxDepth) +} + +func (s *sqlIDStore) registerHash(h KeyBytes) error { + // should be registered 
by the handler code + return nil +} + +func (s *sqlIDStore) iter(from KeyBytes) (iterator, error) { + if len(from) != s.keyLen { + panic("BUG: invalid key length") + } + return newDBRangeIterator(s.db, s.query, from, sqlMaxChunkSize) +} + +type dbBackedStore struct { + *sqlIDStore + *inMemIDStore + maxDepth int +} + +var _ idStore = &dbBackedStore{} + +func newDBBackedStore(db sql.Database, query string, keyLen, maxDepth int) *dbBackedStore { + return &dbBackedStore{ + sqlIDStore: newSQLIDStore(db, query, keyLen, maxDepth), + inMemIDStore: newInMemIDStore(keyLen, maxDepth), + maxDepth: maxDepth, + } +} + +func (s *dbBackedStore) clone() idStore { + return &dbBackedStore{ + sqlIDStore: s.sqlIDStore.clone().(*sqlIDStore), + inMemIDStore: s.inMemIDStore.clone().(*inMemIDStore), + maxDepth: s.maxDepth, + } +} + +func (s *dbBackedStore) registerHash(h KeyBytes) error { + return s.inMemIDStore.registerHash(h) +} + +func (s *dbBackedStore) iter(from KeyBytes) (iterator, error) { + dbIt, err := s.sqlIDStore.iter(from) + if err != nil { + if errors.Is(err, errEmptySet) { + return s.inMemIDStore.iter(from) + } + return nil, err + } + memIt, err := s.inMemIDStore.iter(from) + if err == nil { + return combineIterators(dbIt, memIt), nil + } else if errors.Is(err, errEmptySet) { + return dbIt, nil + } + return nil, err +} + +func idWithinInterval(id, x, y KeyBytes, itype int) bool { + switch itype { + case 0: + return true + case -1: + return bytes.Compare(id, x) >= 0 && bytes.Compare(id, y) < 0 + default: + return bytes.Compare(id, y) < 0 || bytes.Compare(id, x) >= 0 + } +} diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go new file mode 100644 index 0000000000..a1a47d3067 --- /dev/null +++ b/sync2/dbsync/sqlidstore_test.go @@ -0,0 +1,50 @@ +package dbsync + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDBBackedStore(t *testing.T) { + initialIDs := []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + 
{0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + } + db := populateDB(t, 8, initialIDs) + store := newDBBackedStore(db, fakeIDQuery, 8, 24) + var actualIDs []KeyBytes + it, err := store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) + require.NoError(t, err) + for range 5 { + actualIDs = append(actualIDs, it.Key().(KeyBytes)) + require.NoError(t, it.Next()) + } + require.Equal(t, []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) + actualIDs = nil + it, err = store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) + require.NoError(t, err) + for range 6 { + actualIDs = append(actualIDs, it.Key().(KeyBytes)) + require.NoError(t, it.Next()) + } + require.Equal(t, []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 2, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + {0, 0, 0, 9, 0, 0, 0, 0}, + }, actualIDs) +} diff --git a/sync2/internal/skiplist/skiplist.go b/sync2/internal/skiplist/skiplist.go new file mode 100644 index 0000000000..091e234717 --- /dev/null +++ b/sync2/internal/skiplist/skiplist.go @@ -0,0 +1,185 @@ +package skiplist + +import ( + "bytes" + "hash/maphash" + "math" + "slices" +) + +const ( + maxSLHeight = 24 + pValue = 1 / math.E // ref: https://www.sciencedirect.com/science/article/pii/030439759400296U +) + +func randUint32() uint32 { + // this effectively calls runtime.rand which is much faster than math/rand + return uint32(new(maphash.Hash).Sum64()) +} + +// inspired by https://www.cloudcentric.dev/implementing-a-skip-list-in-go/ +var probabilities = calcProbabilities() + +func calcProbabilities() [maxSLHeight]uint32 { + var probs [maxSLHeight]uint32 + probability := 1.0 + for level := 0; level < maxSLHeight; level++ { + 
probs[level] = uint32(probability * float64(math.MaxUint32)) + probability *= pValue + } + return probs +} + +func randomHeight() int { + v := randUint32() + height := 1 + for height < maxSLHeight && v <= probabilities[height] { + height++ + } + return height +} + +type Node struct { + key []byte + nextNodes []*Node +} + +func (n *Node) height() int { + return len(n.nextNodes) +} + +// Key returns the key of the node. +func (n *Node) Key() []byte { + return n.key +} + +// Next returns the node following this one, or nil if there's +// no next node. +func (n *Node) Next() *Node { + return n.nextNodes[0] +} + +// SkipList represents an insert-only skip list. +type SkipList struct { + keySize int + head *Node + // TBD: rm: no much sense in this pool as nodes aren't + // released. Global non-sync.Pool pool might make sense + // nodePools [maxSLHeight]sync.Pool +} + +func New(keySize int) *SkipList { + sl := &SkipList{ + keySize: keySize, + head: &Node{}, + } + // for n := range sl.nodePools { + // sl.nodePools[n].New = func() interface{} { + // return &Node{ + // key: make([]byte, keySize), + // nextNodes: make([]*Node, n+1), + // } + // } + // } + return sl +} + +func (sl *SkipList) FindGTENode(key []byte) *Node { + var next, candidate *Node + node := sl.head +OUTER: + for l := sl.head.height() - 1; l >= 0; l-- { + next = node.nextNodes[l] + for next != nil { + switch bytes.Compare(next.key, key) { + case -1: + // The next node is still below target key, advance to it + node = next + next = node.nextNodes[l] + case 0: + // Found an exact match + return next + default: + // The next node is beyond the target key, try to find a + // smaller key that's >= target key on a lower level. + // Failing that, stick with what we found so far. 
+ candidate = next + continue OUTER + } + } + } + + return candidate +} + +func (sl *SkipList) newNode(height int, key []byte) *Node { + // newNode := sl.nodePools[height-1].Get().(*Node) + // copy(newNode.key, key) + newNode := &Node{ + key: slices.Clone(key), + nextNodes: make([]*Node, height), + } + return newNode +} + +// First returns the first node in the skip list. +func (sl *SkipList) First() *Node { + if sl.head.height() == 0 { + return nil + } + return sl.head.Next() +} + +// Add adds key to the skiplist if it's not yet present there. +func (sl *SkipList) Add(key []byte) { + var ( + prevs [maxSLHeight]*Node + next *Node + ) + + height := randomHeight() + newNode := sl.newNode(height, key) + prev := sl.head + oldHeight := sl.head.height() + for l := oldHeight - 1; l >= 0; l-- { + next = prev.nextNodes[l] + INNER: + for next != nil { + switch bytes.Compare(next.key, key) { + case -1: + // The next node is still below target key, advance to it + prev = next + next = next.nextNodes[l] + case 0: + // Exact match, skip adding duplicate entry + return + case 1: + // The next node is beyond the target key, record it + // as the previous node at this level and proceed + // to the lower level. + break INNER + } + } + prevs[l] = prev + } + + sl.grow(height) + for l := range min(height, oldHeight) { + newNode.nextNodes[l] = prevs[l].nextNodes[l] + prevs[l].nextNodes[l] = newNode + } + for l := oldHeight; l < height; l++ { + newNode.nextNodes[l] = nil + sl.head.nextNodes[l] = newNode + } +} + +func (sl *SkipList) grow(newHeight int) { + if newHeight <= sl.head.height() { + return + } + if newHeight > maxSLHeight { + panic("BUG: skiplist height too high") + } + sl.head.nextNodes = append(sl.head.nextNodes, make([]*Node, newHeight-sl.head.height())...) 
+} diff --git a/sync2/internal/skiplist/skiplist_test.go b/sync2/internal/skiplist/skiplist_test.go new file mode 100644 index 0000000000..806c3af252 --- /dev/null +++ b/sync2/internal/skiplist/skiplist_test.go @@ -0,0 +1,167 @@ +package skiplist + +import ( + "bytes" + crand "crypto/rand" + "fmt" + "math/rand/v2" + "slices" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// checkSanity is based on test from https://github.com/sean-public/fast-skiplist +func checkSanity(t *testing.T, sl *SkipList) { + // each level must be correctly ordered + for l, next := range sl.head.nextNodes { + if next == nil { + continue + } + + if l > len(next.nextNodes) { + t.Fatal("first node's level must be no less than current level") + } + + for next.nextNodes[l] != nil { + require.Positive(t, bytes.Compare(next.nextNodes[l].key, next.key), + "next key value must be greater than prev key value. [next:%v] [prev:%v]", + next.nextNodes[l].key, next.key) + require.GreaterOrEqual(t, len(next.nextNodes), l, + "node's level must be >= it's predecessor level. 
[cur:%v] [nextNodes:%v]", + l, next.nextNodes) + + next = next.nextNodes[l] + } + } +} + +func enumerate(sl *SkipList) [][]byte { + var r [][]byte + for node := sl.First(); node != nil; node = node.Next() { + r = append(r, node.Key()) + } + return r +} + +func dumpSkipList(t *testing.T, sl *SkipList) { + nodeIndices := make(map[*Node]int) + n := 0 + for node := sl.First(); node != nil; node = node.Next() { + nodeIndices[node] = n + n++ + } + for l, next := range sl.head.nextNodes { + var b strings.Builder + for next != nil { + fmt.Fprintf(&b, " --> %d", nodeIndices[next]) + next = next.nextNodes[l] + } + t.Logf("Level %02d:%s", l, b.String()) + } + + for node := sl.First(); node != nil; node = node.Next() { + t.Logf("Node %d: %v (height %d)", nodeIndices[node], node.Key(), node.height()) + } +} + +func TestSkipList(t *testing.T) { + sl := New(4) + require.Nil(t, sl.First()) + require.Nil(t, sl.FindGTENode([]byte{0, 0, 0, 0})) + require.Nil(t, sl.FindGTENode([]byte{1, 2, 3, 4})) + + for _, v := range [][]byte{ + {0, 0, 0, 0}, + {1, 2, 3, 4}, + {5, 6, 7, 9}, + {100, 200, 120, 1}, + {50, 10, 1, 9}, + {11, 33, 54, 22}, + {8, 3, 9, 5}, + } { + t.Logf("add: %v", v) + sl.Add(v) + // dumpSkipList(t, sl) + checkSanity(t, sl) + } + checkSanity(t, sl) + require.Equal(t, [][]byte{ + {0, 0, 0, 0}, + {1, 2, 3, 4}, + {5, 6, 7, 9}, + {8, 3, 9, 5}, + {11, 33, 54, 22}, + {50, 10, 1, 9}, + {100, 200, 120, 1}, + }, enumerate(sl)) + require.Equal(t, sl.First(), sl.FindGTENode([]byte{0, 0, 0, 0})) + require.Equal(t, []byte{1, 2, 3, 4}, sl.FindGTENode([]byte{1, 2, 3, 4}).Key()) + require.Equal(t, []byte{1, 2, 3, 4}, sl.FindGTENode([]byte{1, 2, 3, 0}).Key()) + require.Equal(t, []byte{50, 10, 1, 9}, sl.FindGTENode([]byte{50, 10, 1, 9}).Key()) + require.Equal(t, []byte{50, 10, 1, 9}, sl.FindGTENode([]byte{50, 0, 0, 0}).Key()) + require.Equal(t, []byte{100, 200, 120, 1}, sl.FindGTENode([]byte{100, 200, 120, 1}).Key()) + require.Equal(t, []byte{100, 200, 120, 1}, sl.FindGTENode([]byte{99, 0, 
0, 1}).Key()) + require.Nil(t, sl.FindGTENode([]byte{101, 0, 0, 0})) + + for _, v := range [][]byte{ + {5, 5, 5, 5}, + {100, 200, 120, 1}, + {7, 8, 9, 10}, + {11, 12, 13, 15}, + } { + t.Logf("add: %v", v) + sl.Add(v) + // dumpSkipList(t, sl) + checkSanity(t, sl) + } + checkSanity(t, sl) + require.Equal(t, [][]byte{ + {0, 0, 0, 0}, + {1, 2, 3, 4}, + {5, 5, 5, 5}, + {5, 6, 7, 9}, + {7, 8, 9, 10}, + {8, 3, 9, 5}, + {11, 12, 13, 15}, + {11, 33, 54, 22}, + {50, 10, 1, 9}, + {100, 200, 120, 1}, + }, enumerate(sl)) + require.Equal(t, []byte{11, 12, 13, 15}, sl.FindGTENode([]byte{11, 12, 13, 15}).Key()) + require.Equal(t, []byte{11, 12, 13, 15}, sl.FindGTENode([]byte{11, 10, 1, 1}).Key()) +} + +func TestRandomSkipList(t *testing.T) { + for i := 0; i < 100; i++ { + sl := New(4) + n := rand.IntN(10000) + 1 + expect := make([][]byte, n) + generated := make(map[[4]byte]struct{}) + for j := range expect { + var b [4]byte + for { + _, err := crand.Read(b[:]) + require.NoError(t, err) + if _, ok := generated[b]; !ok { + generated[b] = struct{}{} + break + } + } + expect[j] = slices.Clone(b[:]) + sl.Add(b[:]) + } + checkSanity(t, sl) + slices.SortFunc(expect, func(a, b []byte) int { return bytes.Compare(a, b) }) + require.Equal(t, expect, enumerate(sl)) + for i := 0; i < min(10, n); i++ { + key := expect[rand.IntN(len(expect))] + node := sl.FindGTENode(key) + require.NotNil(t, node) + require.Equal(t, key, node.Key()) + } + } +} + +// TBD: benchmark From cdb3f7a79a922306c727ad397b279dbf17e5967d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 24 Jul 2024 12:25:09 +0400 Subject: [PATCH 46/76] fptree: return iterators from aggregation --- sync2/dbsync/fptree.go | 441 +++++++++++++-------- sync2/dbsync/fptree_test.go | 638 ++++++++++++++++++++---------- sync2/dbsync/inmemidstore.go | 5 +- sync2/dbsync/inmemidstore_test.go | 7 + 4 files changed, 718 insertions(+), 373 deletions(-) diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 38e3bb98da..d299c38afd 100644 
--- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -221,6 +221,18 @@ func (p prefix) minID(b KeyBytes) { } } +func (p prefix) idAfter(b KeyBytes) { + if len(b) < 8 { + panic("BUG: id slice too small") + } + s := uint64(64 - p.len()) + v := (p.bits() + 1) << s + binary.BigEndian.PutUint64(b, v) + for n := 8; n < len(b); n++ { + b[n] = 0xff + } +} + // QQQQQ: rm ? // func (p prefix) maxID(b KeyBytes) { // if len(b) < 8 { @@ -272,39 +284,32 @@ func commonPrefix(a, b KeyBytes) prefix { } type fpResult struct { - fp fingerprint - count uint32 - itype int + fp fingerprint + count uint32 + itype int + start, end iterator } type aggContext struct { - x, y KeyBytes - fp fingerprint - count uint32 - itype int - limit int - total uint32 + x, y KeyBytes + fp fingerprint + count uint32 + itype int + limit int + total uint32 + start, end iterator + lastPrefix *prefix } -func (ac *aggContext) prefixWithinRange(p prefix) bool { - if ac.itype == 0 { - return true - } - x := load64(ac.x) - y := load64(ac.y) - v := p.bits() << (64 - p.len()) - maxV := v + (1 << (64 - p.len())) - 1 - if ac.itype < 0 { - // normal interval - // fmt.Fprintf(os.Stderr, "QQQQQ: (0) itype %d x %016x y %016x v %016x maxV %016x result %v\n", ac.itype, x, y, v, maxV, v >= x && maxV < y) - return v >= x && maxV < y - } - // inverted interval - // fmt.Fprintf(os.Stderr, "QQQQQ: (1) itype %d x %016x y %016x v %016x maxV %016x result %v\n", ac.itype, x, y, v, maxV, v >= x || v < y) - return v >= x || maxV < y +func (ac *aggContext) prefixAtOrAfterX(p prefix) bool { + return p.bits()<<(64-p.len()) >= load64(ac.x) } -func (ac *aggContext) maybeIncludeNode(node node) bool { +func (ac *aggContext) prefixBelowY(p prefix) bool { + return (p.bits()+1)<<(64-p.len())-1 < load64(ac.y) +} + +func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { switch { case ac.limit < 0: case uint32(ac.limit) < node.c: @@ -314,6 +319,7 @@ func (ac *aggContext) maybeIncludeNode(node node) bool { } 
ac.fp.update(node.fp[:]) ac.count += node.c + ac.lastPrefix = &p return true } @@ -386,7 +392,7 @@ func (ft *fpTree) clone() *fpTree { } func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: fpA %s fpB %s p %s\n", fpA.String(), fpB.String(), p) + // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: fpA %s fpB %s p %s\n", fpA.(fmt.Stringer), fpB.(fmt.Stringer), p) fpCombined := fpA fpCombined.update(fpB[:]) if ft.maxDepth != 0 && p.len() == ft.maxDepth { @@ -428,7 +434,7 @@ func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) node func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { if idx == noIndex { r := ft.np.add(fp, 1, noIndex, noIndex) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: addNew fp %s p %s => %d\n", fp.String(), p.String(), r) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: addNew fp %s p %s => %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), r) return r } node := ft.np.node(idx) @@ -437,30 +443,30 @@ func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { // we're at a leaf node, need to push down the old fingerprint, or, // if we've reached the max depth, just update the current node r := ft.pushDown(fp, node.fp, p, node.c) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: pushDown fp %s p %s oldIdx %d => %d\n", fp.String(), p.String(), idx, r) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: pushDown fp %s p %s oldIdx %d => %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx, r) return r } fpCombined := fp fpCombined.update(node.fp[:]) if fp.bitFromLeft(p.len()) { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d\n", fp.String(), p.String(), idx) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx) if node.left != noIndex { ft.np.ref(node.left) // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: ref left %d -- refCount %d\n", 
node.left, ft.np.entry(node.left).refCount) } newRight := ft.addValue(fp, p.right(), node.right) r := ft.np.add(fpCombined, node.c+1, node.left, newRight) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d => %d node.left %d newRight %d\n", fp.String(), p.String(), idx, r, node.left, newRight) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d => %d node.left %d newRight %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx, r, node.left, newRight) return r } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d\n", fp.String(), p.String(), idx) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx) if node.right != noIndex { // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: ref right %d -- refCount %d\n", node.right, ft.np.entry(node.right).refCount) ft.np.ref(node.right) } newLeft := ft.addValue(fp, p.left(), node.left) r := ft.np.add(fpCombined, node.c+1, newLeft, node.right) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d => %d newLeft %d node.right %d\n", fp.String(), p.String(), idx, r, newLeft, node.right) + // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d => %d newLeft %d node.right %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx, r, newLeft, node.right) return r } } @@ -482,19 +488,24 @@ func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeInde ft.enter("followPrefix: from %d p %s highBit %v", from, p, p.highBit()) defer func() { ft.leave(idx, rp, found) }() - // QQQQQ: refactor into a loop - switch { - case from == noIndex: - return noIndex, followed, false - case p == 0: - return from, followed, true - case ft.np.node(from).leaf(): - return from, followed, false - case p.highBit(): - return ft.followPrefix(ft.np.node(from).right, p.shift(), followed.right()) - default: - return ft.followPrefix(ft.np.node(from).left, p.shift(), 
followed.left()) + for from != noIndex { + switch { + case p == 0: + return from, followed, true + case ft.np.node(from).leaf(): + return from, followed, false + case p.highBit(): + from = ft.np.node(from).right + p = p.shift() + followed = followed.right() + default: + from = ft.np.node(from).left + p = p.shift() + followed = followed.left() + } } + + return noIndex, followed, false } // aggregateEdge aggregates an edge of the interval, which can be bounded by x, y, both x @@ -503,12 +514,12 @@ func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeInde // It returns a boolean indicating whether the limit or the right edge (y) was reached and // an error, if any. func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont bool, err error) { - ft.log("aggregateEdge: x %s y %s p %s limit %d count %d", x.String(), y.String(), p, ac.limit, ac.count) + ft.log("aggregateEdge: x %s y %s p %s limit %d count %d", x, y, p, ac.limit, ac.count) defer func() { ft.log("aggregateEdge ==> limit %d count %d\n", ac.limit, ac.count) }() - if ac.limit == 0 { - ft.log("aggregateEdge: limit is 0") + if ac.limit == 0 && ac.end != nil { + ft.log("aggregateEdge: limit is 0 and end already set") return false, nil } var startFrom KeyBytes @@ -518,7 +529,7 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b } else { startFrom = x } - ft.log("aggregateEdge: startFrom %s", startFrom.String()) + ft.log("aggregateEdge: startFrom %s", startFrom) it, err := ft.iter(startFrom) if err != nil { if errors.Is(err, errEmptySet) { @@ -528,32 +539,45 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b ft.log("aggregateEdge: error: %v", err) return false, err } + if ac.limit == 0 { + ac.end = it.clone() + if x != nil { + ac.start = ac.end + } + ft.log("aggregateEdge: limit is 0 at %s", ac.end.Key().(fmt.Stringer)) + return false, nil + } + if x != nil { + ac.start = it.clone() + } for range 
ft.np.node(ft.root).c { id := it.Key().(KeyBytes) - ft.log("aggregateEdge: ID %s", id.String()) + ft.log("aggregateEdge: ID %s", id) if y != nil && id.Compare(y) >= 0 { - ft.log("aggregateEdge: ID is over Y: %s", id.String()) + ac.end = it + ft.log("aggregateEdge: ID is over Y: %s", id) return false, nil } if !p.match(id) { - ft.log("aggregateEdge: ID doesn't match the prefix: %s", id.String()) - // Got to the end of the tailRef without exhausting the limit + ft.log("aggregateEdge: ID doesn't match the prefix: %s", id) + ac.lastPrefix = &p return true, nil } ac.fp.update(id) ac.count++ if ac.limit > 0 { ac.limit-- - if ac.limit == 0 { - ft.log("aggregateEdge: limit exhausted") - return false, nil - } } if err := it.Next(); err != nil { - ft.log("aggregateEdge: next error: %v", err) + ft.log("aggregateEdge: Next failed: %v", err) return false, err } + if ac.limit == 0 { + ac.end = it + ft.log("aggregateEdge: limit exhausted") + return false, nil + } } return true, nil @@ -566,9 +590,25 @@ func (ft *fpTree) node(idx nodeIndex) (node, bool) { return ft.np.node(idx), true } +// QQQQQ: rm +// func (ft *fpTree) markEnd(p prefix, ac *aggContext) error { +// if ac.end != nil { +// return nil +// } +// k := make(KeyBytes, ft.keyLen) +// p.minID(k) +// it, err := ft.iter(k) +// if err != nil { +// return err +// } +// ac.end = it +// ft.log("markEnd: p %s k %s => %s", p, k, it.Key().(fmt.Stringer)) +// return nil +// } + func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, ac.limit, - ac.fp.String(), ac.count) + ac.fp, ac.count) defer func() { ft.leave(ac.fp, ac.count) }() @@ -578,12 +618,12 @@ func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (c case !ok: ft.log("stop: no node") return true, nil - case ac.limit == 0: - // for ac.limit == 0, it's important that we still visit the node - // so that we can get the 
item immediately following the included items - ft.log("stop: limit exhausted") - return false, nil - case ac.maybeIncludeNode(node): + // case ac.limit == 0: + // // for ac.limit == 0, it's important that we still visit the node + // // so that we can get the item immediately following the included items + // ft.log("stop: limit exhausted") + // return false, ft.markEnd(p, ac) + case ac.maybeIncludeNode(node, p): // node is fully included ft.log("included fully") return true, nil @@ -602,7 +642,7 @@ func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (c pLeft := p.left() left, haveLeft := ft.node(node.left) if haveLeft { - if ac.maybeIncludeNode(left) { + if ac.maybeIncludeNode(left, pLeft) { // left node is fully included, after which // we need to stop somewhere in the right subtree ft.log("include left in full") @@ -633,11 +673,12 @@ func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContex // for ac.limit == 0, it's important that we still visit the node // so that we can get the item immediately following the included items ft.log("stop: no node") - return true, nil - case ac.limit == 0: - ft.log("stop: limit exhausted") - return false, nil - case ac.prefixWithinRange(p) && ac.maybeIncludeNode(node): + // QQQQQ: no mark end.... 
+ return true, nil //ft.markEnd(p, ac) + // case ac.limit == 0: + // ft.log("stop: limit exhausted") + // return false, ft.markEnd(p, ac) + case ac.prefixAtOrAfterX(p) && ac.maybeIncludeNode(node, p): ft.log("including node in full: %s limit %d", p, ac.limit) return ac.limit != 0, nil case p.len() == ft.maxDepth || node.leaf(): @@ -669,14 +710,12 @@ func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggConte node, ok := ft.node(idx) switch { case !ok: - // for ac.limit == 0, it's important that we still visit the node - // so that we can get the item immediately following the included items ft.log("stop: no node") return true, nil - case ac.limit == 0: - ft.log("stop: limit exhausted") - return false, nil - case ac.prefixWithinRange(p) && ac.maybeIncludeNode(node): + // case ac.limit == 0: + // ft.log("stop: limit exhausted") + // return false, ft.markEnd(p, ac) + case ac.prefixBelowY(p) && ac.maybeIncludeNode(node, p): ft.log("including node in full: %s limit %d", p, ac.limit) return ac.limit != 0, nil case p.len() == ft.maxDepth || node.leaf(): @@ -699,124 +738,184 @@ func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggConte } } -func (ft *fpTree) aggregateInterval(ac *aggContext) error { - ft.rootMtx.Lock() - defer ft.rootMtx.Unlock() - ft.enter("aggregateInterval: x %s y %s limit %d", ac.x.String(), ac.y.String(), ac.limit) - defer func() { - ft.leave(ac) - }() +func (ft *fpTree) aggregateXX(ac *aggContext) error { + // [x; x) interval which denotes the whole set unless + // the limit is specified, in which case we need to start aggregating + // with x and wrap around if necessary if ft.root == noIndex { - return nil + ft.log("empty set (no root)") + } else if ac.maybeIncludeNode(ft.np.node(ft.root), 0) { + ft.log("whole set") + } else { + // We need to aggregate up to ac.limit number of items starting + // from x and wrapping around if necessary + return ft.aggregateInverse(ac) } - ac.total = ft.np.node(ft.root).c - 
ac.itype = bytes.Compare(ac.x, ac.y) + return nil +} + +func (ft *fpTree) aggregateSimple(ac *aggContext) error { + // "proper" interval: [x; lca); (lca; y) + p := commonPrefix(ac.x, ac.y) + lcaIdx, lcaPrefix, fullPrefixFound := ft.followPrefix(ft.root, p, 0) + var lca node + if lcaIdx != noIndex { + // QQQQQ: TBD: perhaps just return if lcaIdx == noIndex + lca = ft.np.node(lcaIdx) + } + ft.log("commonPrefix %s lca %d found %v", p, lcaIdx, fullPrefixFound) switch { - case ac.itype == 0: - // the whole set - if ft.root != noIndex { - ft.log("whole set") - _, err := ft.aggregateUpToLimit(ft.root, 0, ac) - return err - } else { - ft.log("empty set (no root)") + case fullPrefixFound && !lca.leaf(): + if lcaPrefix != p { + panic("BUG: bad followedPrefix") } + ft.aggregateLeft(lca.left, load64(ac.x)<<(p.len()+1), p.left(), ac) + ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac) + case lcaIdx == noIndex || !lca.leaf(): + ft.log("commonPrefix %s NOT found b/c no items have it", p) + default: + ft.log("commonPrefix %s -- lca %d", p, lcaIdx) + _, err := ft.aggregateEdge(ac.x, ac.y, lcaPrefix, ac) + return err + } + return nil +} - case ac.itype < 0: - // "proper" interval: [x; lca); (lca; y) - p := commonPrefix(ac.x, ac.y) - lcaIdx, lcaPrefix, fullPrefixFound := ft.followPrefix(ft.root, p, 0) - var lca node - if lcaIdx != noIndex { - // QQQQQ: TBD: perhaps just return if lcaIdx == noIndex - lca = ft.np.node(lcaIdx) - } - ft.log("commonPrefix %s lca %d found %v", p, lcaIdx, fullPrefixFound) - switch { - case fullPrefixFound && !lca.leaf(): - if lcaPrefix != p { - panic("BUG: bad followedPrefix") - } - ft.aggregateLeft(lca.left, load64(ac.x)<<(p.len()+1), p.left(), ac) - ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac) - case lcaIdx == noIndex || !lca.leaf(): - ft.log("commonPrefix %s NOT found b/c no items have it", p) - default: - ft.log("commonPrefix %s -- lca %d", p, lcaIdx) - _, err := ft.aggregateEdge(ac.x, ac.y, lcaPrefix, ac) 
- return err +func (ft *fpTree) aggregateInverse(ac *aggContext) error { + // inverse interval: [min; y); [x; max] + // first, we handle [x; max] part + pf0 := preFirst0(ac.x) + idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, 0) + var pf0Node node + if idx0 != noIndex { + pf0Node = ft.np.node(idx0) + } + ft.log("pf0 %s idx0 %d found %v", pf0, idx0, found) + switch { + case found && !pf0Node.leaf(): + if followedPrefix != pf0 { + panic("BUG: bad followedPrefix") } - + ft.aggregateLeft(idx0, load64(ac.x)<= pY || limit == 0 { + fpr.end = hs.keyAt(p) + break + } // t.Logf("XOR %s", hs[p].String()) - fpr.fp.update(hs[p][:]) + fpr.fp.update(hs.keyAt(p)) limit-- fpr.count++ + p++ } case 1: - pX := hs.findGTE(x) - pY := hs.findGTE(y) - for p := pX; p < len(hs) && limit != 0; p++ { - fpr.fp.update(hs[p][:]) + p := hs.findGTE(x) + fpr.start = hs.keyAt(p) + for { + if p >= len(hs) || limit == 0 { + fpr.end = hs.keyAt(p) + break + } + fpr.fp.update(hs.keyAt(p)) limit-- fpr.count++ + p++ } - for p := 0; p < pY && limit != 0; p++ { - fpr.fp.update(hs[p][:]) + if limit == 0 { + return fpr + } + pY := hs.findGTE(y) + p = 0 + for { + if p == pY || limit == 0 { + fpr.end = hs.keyAt(p) + break + } + fpr.fp.update(hs.keyAt(p)) limit-- fpr.count++ + p++ } default: - for n, h := range hs { - if limit >= 0 && n >= limit { + pX := hs.findGTE(x) + p := pX + fpr.start = hs.keyAt(p) + fpr.end = fpr.start + for { + if limit == 0 { + fpr.end = hs.keyAt(p) break } - fpr.fp.update(h[:]) + fpr.fp.update(hs.keyAt(p)) + limit-- fpr.count++ + p = (p + 1) % l + if p == pX { + break + } } } return fpr } -func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, maxDepth int) { +func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, maxDepth, repeat int) { var np nodePool ft := newFPTree(&np, idStore, 32, maxDepth) // ft.traceEnabled = true @@ -758,8 +968,10 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, 
fpr, err := ft.fingerprintInterval(hs[0][:], hs[0][:], -1) require.NoError(t, err) - require.Equal(t, fpResult{fp: fp, count: uint32(numItems), itype: 0}, fpr) - for i := 0; i < 100; i++ { + require.Equal(t, fp, fpr.fp, "fp") + require.Equal(t, uint32(numItems), fpr.count, "count") + require.Equal(t, 0, fpr.itype, "itype") + for i := 0; i < repeat; i++ { // TBD: allow reverse order var x, y types.Hash32 if randomXY { @@ -774,7 +986,7 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, require.NoError(t, err) // QQQQQ: rm - if !reflect.DeepEqual(fpr, expFPR) { + if !reflect.DeepEqual(toFPResultWithBounds(fpr), expFPR) { t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) for _, h := range hs { t.Logf("QQQQQ: hash: %s", h.String()) @@ -785,7 +997,8 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, } // QQQQQ: /rm - require.Equal(t, expFPR, fpr) + require.Equal(t, expFPR, toFPResultWithBounds(fpr), + "x=%s y=%s", x.String(), y.String()) limit := 0 if fpr.count != 0 { @@ -796,7 +1009,7 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, require.NoError(t, err) // QQQQQ: rm - if !reflect.DeepEqual(fpr, expFPR) { + if !reflect.DeepEqual(toFPResultWithBounds(fpr), expFPR) { t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) for _, h := range hs { t.Logf("QQQQQ: hash: %s", h.String()) @@ -807,43 +1020,50 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, } // QQQQQ: /rm - require.Equal(t, expFPR, fpr, "x=%s y=%s limit=%d", x.String(), y.String(), limit) + require.Equal(t, expFPR, toFPResultWithBounds(fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), limit) } } func TestFPTreeManyItems(t *testing.T) { const ( - // numItems = 1 << 16 - // maxDepth = 24 - numItems = 1 << 2 // 1 << 5 - maxDepth = 4 + repeatOuter = 30 + repeatInner = 20 + numItems = 1 << 13 + maxDepth = 12 + // numItems = 1 << 5 + // maxDepth = 4 ) t.Run("bounds from the set", 
func(t *testing.T) { + t.Parallel() repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { return newInMemIDStore(32, maxDepth) - }, false, numItems, maxDepth) + }, false, numItems, maxDepth, repeatOuter, repeatInner) }) t.Run("random bounds", func(t *testing.T) { + t.Parallel() repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { return newInMemIDStore(32, maxDepth) - }, true, numItems, maxDepth) + }, true, numItems, maxDepth, repeatOuter, repeatInner) }) t.Run("SQL, bounds from the set", func(t *testing.T) { + t.Parallel() db := populateDB(t, 32, nil) repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { _, err := db.Exec("delete from foo", nil, nil) require.NoError(t, err) return newFakeATXIDStore(db, maxDepth) - }, false, numItems, maxDepth) + }, false, numItems, maxDepth, repeatOuter, repeatInner) }) t.Run("SQL, random bounds", func(t *testing.T) { + t.Parallel() db := populateDB(t, 32, nil) repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { _, err := db.Exec("delete from foo", nil, nil) require.NoError(t, err) return newFakeATXIDStore(db, maxDepth) - }, true, numItems, maxDepth) + }, true, numItems, maxDepth, repeatOuter, repeatInner) }) // TBD: test limits with both random and non-random bounds // TBD: test start/end iterators @@ -968,7 +1188,8 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { //expFPResult := dumbAggATXs(t, db, x, y) fpr, err := ft.fingerprintInterval(x[:], y[:], -1) require.NoError(t, err) - require.Equal(t, expFPResult, fpr, "x=%s y=%s", x.String(), y.String()) + require.Equal(t, expFPResult, toFPResultWithBounds(fpr), + "x=%s y=%s", x.String(), y.String()) limit := 0 if fpr.count != 0 { @@ -978,13 +1199,14 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { expFPResult = dumbFP(*hs, x, y, limit) fpr, err = ft.fingerprintInterval(x[:], y[:], limit) require.NoError(t, err) - require.Equal(t, expFPResult, fpr, "x=%s y=%s limit=%d", x.String(), y.String(), limit) + require.Equal(t, 
expFPResult, toFPResultWithBounds(fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), limit) } - // x := types.HexToHash32("ab27e01be51af3775fa20299767aef712128a021ffdca7617b31c9ca811376d2") - // y := types.HexToHash32("20c64bc7ea2114babe08380be6cc379aebc715f3820aca52f44fe10748d792a3") - // t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) - // expFPResult := dumbFP(hs, x, y) + // x := types.HexToHash32("930a069661bf21b52aa79a4b5149ecc1190282f1386b6b8ae6b738153a7a802d") + // y := types.HexToHash32("6c966fc65c07c92e869b7796b2346a33e01c4fe38c25094a480cdcd2e7df1f56") + // t.Logf("QQQQQ: maxDepth=%d x=%s y=%s", maxDepth, x.String(), y.String()) + // expFPResult := dumbFP(*hs, x, y, -1) // //expFPResult := dumbAggATXs(t, db, x, y) // ft.traceEnabled = true // fpr, err := ft.fingerprintInterval(x[:], y[:], -1) @@ -1073,3 +1295,17 @@ func TestATXFP(t *testing.T) { // TODO: QQQQQ: retrieve the end of the interval w/count in fpTree.fingerprintInterval() // TODO: QQQQQ: test limits in TestInMemFPTreeManyItems (sep test cases SQL / non-SQL) // TODO: the returned RangeInfo.End iterators should be cyclic + +// TBD: random off-by-1 failure? 
+// --- Expected +// +++ Actual +// @@ -2,5 +2,5 @@ +// fp: (dbsync.fingerprint) (len=12) { +// - 00000000 30 d4 db 9d b9 15 dd ad 75 1e 67 fd |0.......u.g.| +// + 00000000 a3 de 4b 89 7b 93 fc 76 24 88 82 b2 |..K.{..v$...| +// }, +// - count: (uint32) 41784134, +// + count: (uint32) 41784135, +// itype: (int) 1 +// Test: TestATXFP +// Messages: x=930a069661bf21b52aa79a4b5149ecc1190282f1386b6b8ae6b738153a7a802d y=6c966fc65c07c92e869b7796b2346a33e01c4fe38c25094a480cdcd2e7df1f56 diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index f970294d08..1dc6519ba7 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -39,7 +39,10 @@ func (s *inMemIDStore) registerHash(h KeyBytes) error { func (s *inMemIDStore) iter(from KeyBytes) (iterator, error) { node := s.sl.FindGTENode(from) if node == nil { - return nil, errEmptySet + node = s.sl.First() + if node == nil { + return nil, errEmptySet + } } return &inMemIDStoreIterator{sl: s.sl, node: node}, nil } diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go index b6943d08e8..59650ae723 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -74,5 +74,12 @@ func TestInMemIDStore(t *testing.T) { require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(it.Key().(KeyBytes))) + + it, err = s1.iter( + util.FromHex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0")) + require.NoError(t, err) + require.Equal(t, + "0000000000000000000000000000000000000000000000000000000000000000", + hex.EncodeToString(it.Key().(KeyBytes))) } } From a77e698d31649df764f15153a84a0c788443c345 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 24 Jul 2024 18:28:24 +0400 Subject: [PATCH 47/76] fptree: test and fix empty range handling --- sync2/dbsync/fptree.go | 2 +- sync2/dbsync/fptree_test.go | 54 ++++++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 4 
deletions(-) diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index d299c38afd..6613ae88e8 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -849,11 +849,11 @@ func (ft *fpTree) aggregateInterval(ac *aggContext) error { defer func() { ft.leave(ac) }() + ac.itype = bytes.Compare(ac.x, ac.y) if ft.root == noIndex { return nil } ac.total = ft.np.node(ft.root).c - ac.itype = bytes.Compare(ac.x, ac.y) switch ac.itype { case 0: return ft.aggregateXX(ac) diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 828999f1d8..303f386e1d 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -232,7 +232,53 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { ranges []rangeTestCase x, y string }{ - // TBD: QQQQQ: test empty set + { + name: "empty", + maxDepth: 24, + ids: nil, + ranges: []rangeTestCase{ + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: -1, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: 0, + startIdx: -1, + endIdx: -1, + }, + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: 0, + startIdx: -1, + endIdx: -1, + }, + { + x: "123456789abcdef0000000000000000000000000000000000000000000000000", + y: "223456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: -1, + startIdx: -1, + endIdx: -1, + }, + { + x: "223456789abcdef0000000000000000000000000000000000000000000000000", + y: "123456789abcdef0000000000000000000000000000000000000000000000000", + limit: 1, + fp: hexToFingerprint("000000000000000000000000"), + count: 0, + itype: 1, + startIdx: -1, + endIdx: -1, + }, + }, + }, { name: 
"ids1", maxDepth: 24, @@ -836,7 +882,9 @@ func checkNode(t *testing.T, ft *fpTree, idx nodeIndex, depth int) { func checkTree(t *testing.T, ft *fpTree, maxDepth int) { require.Equal(t, maxDepth, ft.maxDepth) - checkNode(t, ft, ft.root, 0) + if ft.root != noIndex { + checkNode(t, ft, ft.root, 0) + } } func repeatTestFPTreeManyItems( @@ -1215,7 +1263,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { } func TestATXFP(t *testing.T) { - // t.Skip("slow test") + t.Skip("slow test") var hs []types.Hash32 for maxDepth := 15; maxDepth <= 23; maxDepth++ { for i := 0; i < 3; i++ { From 53a241c90bf50bcbf784ba42b8929f6f21109c31 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 24 Jul 2024 19:47:07 +0400 Subject: [PATCH 48/76] fptree: add idStore.start() method --- sync2/dbsync/dbitemstore.go | 33 ++++++++++++++++++------------- sync2/dbsync/fptree.go | 1 + sync2/dbsync/inmemidstore.go | 8 ++++++++ sync2/dbsync/inmemidstore_test.go | 19 ++++++++++++++---- sync2/dbsync/sqlidstore.go | 23 ++++++++++++++++++++- sync2/dbsync/sqlidstore_test.go | 7 +++++++ 6 files changed, 72 insertions(+), 19 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 7f8e34228e..1e4cd63d7c 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -2,6 +2,7 @@ package dbsync import ( "context" + "errors" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" @@ -40,31 +41,35 @@ func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { return d.ft.addHash(k.(KeyBytes)) } -func (d *DBItemStore) iter(min, max KeyBytes) (hashsync.Iterator, error) { - panic("TBD") - // return newDBRangeIterator(d.db, d.query, min, max, d.chunkSize) -} - // GetRangeInfo implements hashsync.ItemStore. func (d *DBItemStore) GetRangeInfo( preceding hashsync.Iterator, x, y hashsync.Ordered, count int, ) (hashsync.RangeInfo, error) { - // QQQQQ: note: iter's max is inclusive!!!! 
- // TBD: QQQQQ: need count limiting in ft.fingerprintInterval - panic("unimplemented") + fpr, err := d.ft.fingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) + if err != nil { + return hashsync.RangeInfo{}, err + } + return hashsync.RangeInfo{ + Fingerprint: fpr.fp, + Count: count, + Start: fpr.start, + End: fpr.end, + }, nil } // Min implements hashsync.ItemStore. func (d *DBItemStore) Min() (hashsync.Iterator, error) { - // INCORRECT !!! should return nil if the store is empty - it1 := make(KeyBytes, d.keyLen) - it2 := make(KeyBytes, d.keyLen) - for i := range it2 { - it2[i] = 0xff + it, err := d.ft.start() + switch { + case err == nil: + return it, nil + case errors.Is(err, errEmptySet): + return nil, nil + default: + return nil, err } - return d.iter(it1, it2) } // Copy implements hashsync.ItemStore. diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 6613ae88e8..da28c9094f 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -331,6 +331,7 @@ type iterator interface { type idStore interface { clone() idStore registerHash(h KeyBytes) error + start() (iterator, error) iter(from KeyBytes) (iterator, error) } diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index 1dc6519ba7..def4863b22 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -36,6 +36,14 @@ func (s *inMemIDStore) registerHash(h KeyBytes) error { return nil } +func (s *inMemIDStore) start() (iterator, error) { + node := s.sl.First() + if node == nil { + return nil, errEmptySet + } + return &inMemIDStoreIterator{sl: s.sl, node: node}, nil +} + func (s *inMemIDStore) iter(from KeyBytes) (iterator, error) { node := s.sl.FindGTENode(from) if node == nil { diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go index 59650ae723..39c4fcc9d4 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -10,9 +10,16 @@ import ( ) func TestInMemIDStore(t *testing.T) { + var 
( + it iterator + err error + ) s := newInMemIDStore(32, 24) - _, err := s.iter(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + _, err = s.start() + require.ErrorIs(t, err, errEmptySet) + + _, err = s.iter(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) require.ErrorIs(t, err, errEmptySet) for _, h := range []string{ @@ -27,9 +34,13 @@ func TestInMemIDStore(t *testing.T) { s.registerHash(util.FromHex(h)) } - for range 2 { - it, err := s.iter( - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + for i := range 6 { + if i%2 == 0 { + it, err = s.start() + } else { + it, err = s.iter( + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + } require.NoError(t, err) var items []string for range 7 { diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 320bec9918..a6717bc456 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -13,7 +13,7 @@ type sqlIDStore struct { db sql.Database query string keyLen int - maxDepth int + maxDepth int // TBD: remove } var _ idStore = &sqlIDStore{} @@ -31,6 +31,10 @@ func (s *sqlIDStore) registerHash(h KeyBytes) error { return nil } +func (s *sqlIDStore) start() (iterator, error) { + return s.iter(make(KeyBytes, s.keyLen)) +} + func (s *sqlIDStore) iter(from KeyBytes) (iterator, error) { if len(from) != s.keyLen { panic("BUG: invalid key length") @@ -66,6 +70,23 @@ func (s *dbBackedStore) registerHash(h KeyBytes) error { return s.inMemIDStore.registerHash(h) } +func (s *dbBackedStore) start() (iterator, error) { + dbIt, err := s.sqlIDStore.start() + if err != nil { + if errors.Is(err, errEmptySet) { + return s.inMemIDStore.start() + } + return nil, err + } + memIt, err := s.inMemIDStore.start() + if err == nil { + return combineIterators(dbIt, memIt), nil + } else if errors.Is(err, errEmptySet) { + return dbIt, nil + } + return nil, err +} + func (s *dbBackedStore) 
iter(from KeyBytes) (iterator, error) { dbIt, err := s.sqlIDStore.iter(from) if err != nil { diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index a1a47d3067..3b0cb94060 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ b/sync2/dbsync/sqlidstore_test.go @@ -30,6 +30,13 @@ func TestDBBackedStore(t *testing.T) { {0, 0, 0, 1, 0, 0, 0, 0}, // wrapped around }, actualIDs) + it, err = store.start() + require.NoError(t, err) + for n := range 5 { + require.Equal(t, actualIDs[n], it.Key().(KeyBytes)) + require.NoError(t, it.Next()) + } + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) actualIDs = nil From 07e791dc9a0169ec1d0723a4c58a8493c9a46fb9 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 24 Jul 2024 20:38:22 +0400 Subject: [PATCH 49/76] fptree: dump tree stats --- sync2/dbsync/fptree.go | 1 + sync2/dbsync/fptree_test.go | 58 +++++++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index da28c9094f..8228902b1c 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -162,6 +162,7 @@ const ( maxPrefixLen = 64 - prefixLenBits ) +// TODO: use uint32 for prefix type prefix uint64 func mkprefix(bits uint64, l int) prefix { diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 303f386e1d..3c6ae09799 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -2,6 +2,7 @@ package dbsync import ( "fmt" + "math" "math/rand" "reflect" "runtime" @@ -1147,6 +1148,43 @@ const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7 // } // } +func treeStats(t *testing.T, ft *fpTree) { + numNodes := 0 + numCompactable := 0 + numLeafs := 0 + numEarlyLeafs := 0 + minLeafSize := uint32(math.MaxUint32) + maxLeafSize := uint32(0) + totalLeafSize := uint32(0) + var scanNode func(nodeIndex, int) 
bool + scanNode = func(idx nodeIndex, depth int) bool { + if idx == noIndex { + return false + } + numNodes++ + node := ft.np.node(idx) + if node.leaf() { + minLeafSize = min(minLeafSize, node.c) + maxLeafSize = max(maxLeafSize, node.c) + totalLeafSize += node.c + numLeafs++ + if depth < ft.maxDepth { + numEarlyLeafs++ + } + } else { + haveLeft := scanNode(node.left, depth+1) + if !scanNode(node.right, depth+1) || !haveLeft { + numCompactable++ + } + } + return true + } + scanNode(ft.root, 0) + avgLeafSize := float64(totalLeafSize) / float64(numLeafs) + t.Logf("tree stats: numNodes=%d numLeafs=%d numEarlyLeafs=%d numCompactable=%d minLeafSize=%d maxLeafSize=%d avgLeafSize=%f", + numNodes, numLeafs, numEarlyLeafs, numCompactable, minLeafSize, maxLeafSize, avgLeafSize) +} + func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { // t.Skip("slow tmp test") // counts := make(map[uint64]uint64) @@ -1163,7 +1201,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { var np nodePool if *hs == nil { t.Logf("loading IDs") - _, err = db.Exec("select id from atxs order by id", nil, func(stmt *sql.Statement) bool { + _, err = db.Exec("select id from atxs where epoch = 26 order by id", nil, func(stmt *sql.Statement) bool { var id types.Hash32 stmt.ColumnBytes(0, id[:]) *hs = append(*hs, id) @@ -1181,14 +1219,19 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { } // TODO: use testing.B and b.ReportAllocs() - runtime.GC() + for i := 0; i < 3; i++ { + runtime.GC() + time.Sleep(100 * time.Millisecond) + } var stats1 runtime.MemStats runtime.ReadMemStats(&stats1) - store := newSQLIDStore(db, "select id from atxs where id >= ? order by id limit ?", 32, maxDepth) + // TODO: pass extra bind params to the SQL query + store := newSQLIDStore(db, "select id from atxs where id >= ? 
and epoch = 26 order by id limit ?", 32, maxDepth) ft := newFPTree(&np, store, 32, maxDepth) for _, id := range *hs { ft.addHash(id[:]) } + treeStats(t, ft) // countFreq := make(map[uint64]int) // for _, c := range counts { @@ -1215,7 +1258,10 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { } elapsed := time.Now().Sub(ts) - runtime.GC() + for i := 0; i < 3; i++ { + runtime.GC() + time.Sleep(100 * time.Millisecond) + } var stats2 runtime.MemStats runtime.ReadMemStats(&stats2) t.Logf("range benchmark for maxDepth %d: %v per range, %f ranges/s, heap diff %d", @@ -1231,7 +1277,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { for n := 0; n < 50; n++ { x := types.RandomHash() y := types.RandomHash() - t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) + // t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) expFPResult := dumbFP(*hs, x, y, -1) //expFPResult := dumbAggATXs(t, db, x, y) fpr, err := ft.fingerprintInterval(x[:], y[:], -1) @@ -1243,7 +1289,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { if fpr.count != 0 { limit = rand.Intn(int(fpr.count)) } - t.Logf("QQQQQ: x=%s y=%s limit=%d", x.String(), y.String(), limit) + // t.Logf("QQQQQ: x=%s y=%s limit=%d", x.String(), y.String(), limit) expFPResult = dumbFP(*hs, x, y, limit) fpr, err = ft.fingerprintInterval(x[:], y[:], limit) require.NoError(t, err) From 2c1554aa0dcacc4ad480845215ba87afe935d1f1 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 24 Jul 2024 21:43:28 +0400 Subject: [PATCH 50/76] fptree: don't pass maxDepth to tree stores --- sync2/dbsync/dbitemstore.go | 2 +- sync2/dbsync/fptree_test.go | 14 +++++++------- sync2/dbsync/inmemidstore.go | 16 +++++++--------- sync2/dbsync/inmemidstore_test.go | 2 +- sync2/dbsync/sqlidstore.go | 23 ++++++++++------------- sync2/dbsync/sqlidstore_test.go | 2 +- 6 files changed, 27 insertions(+), 32 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 
1e4cd63d7c..79b098236b 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -25,7 +25,7 @@ func NewDBItemStore( query string, keyLen, maxDepth, chunkSize int, ) *DBItemStore { - dbStore := newDBBackedStore(db, query, keyLen, maxDepth) + dbStore := newDBBackedStore(db, query, keyLen) return &DBItemStore{ db: db, ft: newFPTree(np, dbStore, keyLen, maxDepth), diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 3c6ae09799..026922966c 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -200,7 +200,7 @@ var _ idStore = &fakeIDDBStore{} const fakeIDQuery = "select id from foo where id >= ? order by id limit ?" func newFakeATXIDStore(db sql.Database, maxDepth int) *fakeIDDBStore { - return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, fakeIDQuery, 32, maxDepth)} + return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, fakeIDQuery, 32)} } func (s *fakeIDDBStore) registerHash(h KeyBytes) error { @@ -758,7 +758,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { func TestFPTree(t *testing.T) { t.Run("in-memory id store", func(t *testing.T) { testFPTree(t, func(maxDepth int) idStore { - return newInMemIDStore(32, maxDepth) + return newInMemIDStore(32) }) }) t.Run("fake ATX store", func(t *testing.T) { @@ -773,7 +773,7 @@ func TestFPTree(t *testing.T) { func TestFPTreeClone(t *testing.T) { var np nodePool - ft1 := newFPTree(&np, newInMemIDStore(32, 24), 32, 24) + ft1 := newFPTree(&np, newInMemIDStore(32), 32, 24) hashes := []types.Hash32{ types.HexToHash32("1111111111111111111111111111111111111111111111111111111111111111"), types.HexToHash32("3333333333333333333333333333333333333333333333333333333333333333"), @@ -1086,14 +1086,14 @@ func TestFPTreeManyItems(t *testing.T) { t.Run("bounds from the set", func(t *testing.T) { t.Parallel() repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { - return newInMemIDStore(32, maxDepth) + return newInMemIDStore(32) }, false, numItems, 
maxDepth, repeatOuter, repeatInner) }) t.Run("random bounds", func(t *testing.T) { t.Parallel() repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { - return newInMemIDStore(32, maxDepth) + return newInMemIDStore(32) }, true, numItems, maxDepth, repeatOuter, repeatInner) }) t.Run("SQL, bounds from the set", func(t *testing.T) { @@ -1226,7 +1226,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { var stats1 runtime.MemStats runtime.ReadMemStats(&stats1) // TODO: pass extra bind params to the SQL query - store := newSQLIDStore(db, "select id from atxs where id >= ? and epoch = 26 order by id limit ?", 32, maxDepth) + store := newSQLIDStore(db, "select id from atxs where id >= ? and epoch = 26 order by id limit ?", 32) ft := newFPTree(&np, store, 32, maxDepth) for _, id := range *hs { ft.addHash(id[:]) @@ -1309,7 +1309,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { } func TestATXFP(t *testing.T) { - t.Skip("slow test") + // t.Skip("slow test") var hs []types.Hash32 for maxDepth := 15; maxDepth <= 23; maxDepth++ { for i := 0; i < 3; i++ { diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index def4863b22..5fbc56a3a4 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -6,24 +6,22 @@ import ( ) type inMemIDStore struct { - sl *skiplist.SkipList - keyLen int - maxDepth int - len int + sl *skiplist.SkipList + keyLen int + len int } var _ idStore = &inMemIDStore{} -func newInMemIDStore(keyLen, maxDepth int) *inMemIDStore { +func newInMemIDStore(keyLen int) *inMemIDStore { return &inMemIDStore{ - sl: skiplist.New(keyLen), - keyLen: keyLen, - maxDepth: maxDepth, + sl: skiplist.New(keyLen), + keyLen: keyLen, } } func (s *inMemIDStore) clone() idStore { - newStore := newInMemIDStore(s.keyLen, s.maxDepth) + newStore := newInMemIDStore(s.keyLen) for node := s.sl.First(); node != nil; node = node.Next() { newStore.sl.Add(node.Key()) } diff --git a/sync2/dbsync/inmemidstore_test.go 
b/sync2/dbsync/inmemidstore_test.go index 39c4fcc9d4..f8ca3719ed 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -14,7 +14,7 @@ func TestInMemIDStore(t *testing.T) { it iterator err error ) - s := newInMemIDStore(32, 24) + s := newInMemIDStore(32) _, err = s.start() require.ErrorIs(t, err, errEmptySet) diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index a6717bc456..d7bda85b14 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -10,20 +10,19 @@ import ( const sqlMaxChunkSize = 1024 type sqlIDStore struct { - db sql.Database - query string - keyLen int - maxDepth int // TBD: remove + db sql.Database + query string + keyLen int } var _ idStore = &sqlIDStore{} -func newSQLIDStore(db sql.Database, query string, keyLen, maxDepth int) *sqlIDStore { - return &sqlIDStore{db: db, query: query, keyLen: keyLen, maxDepth: maxDepth} +func newSQLIDStore(db sql.Database, query string, keyLen int) *sqlIDStore { + return &sqlIDStore{db: db, query: query, keyLen: keyLen} } func (s *sqlIDStore) clone() idStore { - return newSQLIDStore(s.db, s.query, s.keyLen, s.maxDepth) + return newSQLIDStore(s.db, s.query, s.keyLen) } func (s *sqlIDStore) registerHash(h KeyBytes) error { @@ -32,6 +31,7 @@ func (s *sqlIDStore) registerHash(h KeyBytes) error { } func (s *sqlIDStore) start() (iterator, error) { + // TODO: should probably use a different query to get the first key return s.iter(make(KeyBytes, s.keyLen)) } @@ -45,16 +45,14 @@ func (s *sqlIDStore) iter(from KeyBytes) (iterator, error) { type dbBackedStore struct { *sqlIDStore *inMemIDStore - maxDepth int } var _ idStore = &dbBackedStore{} -func newDBBackedStore(db sql.Database, query string, keyLen, maxDepth int) *dbBackedStore { +func newDBBackedStore(db sql.Database, query string, keyLen int) *dbBackedStore { return &dbBackedStore{ - sqlIDStore: newSQLIDStore(db, query, keyLen, maxDepth), - inMemIDStore: newInMemIDStore(keyLen, maxDepth), - 
maxDepth: maxDepth, + sqlIDStore: newSQLIDStore(db, query, keyLen), + inMemIDStore: newInMemIDStore(keyLen), } } @@ -62,7 +60,6 @@ func (s *dbBackedStore) clone() idStore { return &dbBackedStore{ sqlIDStore: s.sqlIDStore.clone().(*sqlIDStore), inMemIDStore: s.inMemIDStore.clone().(*inMemIDStore), - maxDepth: s.maxDepth, } } diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index 3b0cb94060..f756bac693 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ b/sync2/dbsync/sqlidstore_test.go @@ -14,7 +14,7 @@ func TestDBBackedStore(t *testing.T) { {0, 0, 0, 7, 0, 0, 0, 0}, } db := populateDB(t, 8, initialIDs) - store := newDBBackedStore(db, fakeIDQuery, 8, 24) + store := newDBBackedStore(db, fakeIDQuery, 8) var actualIDs []KeyBytes it, err := store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) require.NoError(t, err) From 64667b98e2c9944b57c14b594a2af7c4cd8caff1 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 25 Jul 2024 00:55:07 +0400 Subject: [PATCH 51/76] dbsync: implement working DBItemStore with test --- sync2/dbsync/dbitemstore.go | 62 ++++++++-- sync2/dbsync/dbitemstore_test.go | 200 +++++++++++++++++++++++++++++++ sync2/dbsync/fptree.go | 46 ++++--- sync2/dbsync/fptree_test.go | 112 ++++++++--------- 4 files changed, 334 insertions(+), 86 deletions(-) create mode 100644 sync2/dbsync/dbitemstore_test.go diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 79b098236b..24655d265e 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -3,41 +3,72 @@ package dbsync import ( "context" "errors" + "sync" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) type DBItemStore struct { + loadMtx sync.Mutex + loaded bool db sql.Database ft *fpTree - query string + loadQuery string + iterQuery string keyLen int maxDepth int - chunkSize int } var _ hashsync.ItemStore = &DBItemStore{} func NewDBItemStore( - np *nodePool, db sql.Database, - query string, - 
keyLen, maxDepth, chunkSize int, + loadQuery, iterQuery string, + keyLen, maxDepth int, ) *DBItemStore { - dbStore := newDBBackedStore(db, query, keyLen) + var np nodePool + dbStore := newDBBackedStore(db, iterQuery, keyLen) return &DBItemStore{ db: db, - ft: newFPTree(np, dbStore, keyLen, maxDepth), - query: query, + ft: newFPTree(&np, dbStore, keyLen, maxDepth), + loadQuery: loadQuery, + iterQuery: iterQuery, keyLen: keyLen, maxDepth: maxDepth, - chunkSize: chunkSize, } } +func (d *DBItemStore) load() error { + _, err := d.db.Exec(d.loadQuery, nil, + func(stmt *sql.Statement) bool { + id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID + stmt.ColumnBytes(0, id[:]) + d.ft.addStoredHash(id) + return true + }) + return err +} + +func (d *DBItemStore) EnsureLoaded() error { + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + if !d.loaded { + if err := d.load(); err != nil { + return err + } + d.loaded = true + } + return nil +} + // Add implements hashsync.ItemStore. func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { + d.EnsureLoaded() + has, err := d.Has(k) // TODO: this check shouldn't be needed + if has || err != nil { + return err + } return d.ft.addHash(k.(KeyBytes)) } @@ -47,13 +78,14 @@ func (d *DBItemStore) GetRangeInfo( x, y hashsync.Ordered, count int, ) (hashsync.RangeInfo, error) { + d.EnsureLoaded() fpr, err := d.ft.fingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) if err != nil { return hashsync.RangeInfo{}, err } return hashsync.RangeInfo{ Fingerprint: fpr.fp, - Count: count, + Count: int(fpr.count), Start: fpr.start, End: fpr.end, }, nil @@ -61,6 +93,7 @@ func (d *DBItemStore) GetRangeInfo( // Min implements hashsync.ItemStore. func (d *DBItemStore) Min() (hashsync.Iterator, error) { + d.EnsureLoaded() it, err := d.ft.start() switch { case err == nil: @@ -74,18 +107,23 @@ func (d *DBItemStore) Min() (hashsync.Iterator, error) { // Copy implements hashsync.ItemStore. 
func (d *DBItemStore) Copy() hashsync.ItemStore { + d.EnsureLoaded() return &DBItemStore{ db: d.db, ft: d.ft.clone(), - query: d.query, + loadQuery: d.loadQuery, + iterQuery: d.iterQuery, keyLen: d.keyLen, maxDepth: d.maxDepth, - chunkSize: d.chunkSize, + loaded: true, } } // Has implements hashsync.ItemStore. func (d *DBItemStore) Has(k hashsync.Ordered) (bool, error) { + d.EnsureLoaded() + // TODO: should often be able to avoid querying the database if we check the key + // against the fptree it, err := d.ft.iter(k.(KeyBytes)) if err == nil { return k.Compare(it.Key()) == 0, nil diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go new file mode 100644 index 0000000000..9d6b1b7f33 --- /dev/null +++ b/sync2/dbsync/dbitemstore_test.go @@ -0,0 +1,200 @@ +package dbsync + +import ( + "context" + "fmt" + "testing" + + "github.com/spacemeshos/go-spacemesh/common/util" + "github.com/stretchr/testify/require" +) + +func TestDBItemStoreEmpty(t *testing.T) { + db := populateDB(t, 8, nil) + s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + it, err := s.Min() + require.NoError(t, err) + require.Nil(t, it) + + info, err := s.GetRangeInfo(nil, + KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), + KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), + -1) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.(fmt.Stringer).String()) + require.Nil(t, info.Start) + require.Nil(t, info.End) + + info, err = s.GetRangeInfo(nil, + KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), + KeyBytes(util.FromHex("9999000000000000000000000000000000000000000000000000000000000000")), + -1) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.(fmt.Stringer).String()) + require.Nil(t, 
info.Start) + require.Nil(t, info.End) +} + +func TestDBItemStore(t *testing.T) { + ids := []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), + } + db := populateDB(t, 8, ids) + s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + it, err := s.Min() + require.NoError(t, err) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + it.Key().(KeyBytes).String()) + has, err := s.Has(KeyBytes(util.FromHex("9876000000000000000000000000000000000000000000000000000000000000"))) + require.NoError(t, err) + require.False(t, has) + + for _, tc := range []struct { + xIdx, yIdx int + limit int + fp string + count int + startIdx, endIdx int + }{ + { + xIdx: 0, + yIdx: 0, + limit: 0, + fp: "000000000000000000000000", + count: 0, + startIdx: 0, + endIdx: 0, + }, + { + xIdx: 1, + yIdx: 1, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + startIdx: 1, + endIdx: 1, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4761032dcfe98ba555555555", + count: 3, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 2, + yIdx: 0, + limit: -1, + fp: "761032cfe98ba54ddddddddd", + count: 3, + startIdx: 2, + endIdx: 0, + }, + { + xIdx: 3, + yIdx: 2, + limit: 3, + fp: "2345679abcdef01888888888", + count: 3, + startIdx: 3, + endIdx: 1, + }, + } { + name := fmt.Sprintf("%d-%d_%d", tc.xIdx, tc.yIdx, tc.limit) + t.Run(name, func(t *testing.T) { + t.Logf("x %s y %s limit %d", ids[tc.xIdx], ids[tc.yIdx], tc.limit) + info, err := s.GetRangeInfo(nil, ids[tc.xIdx], ids[tc.yIdx], tc.limit) + require.NoError(t, err) + require.Equal(t, tc.count, info.Count) + require.Equal(t, tc.fp, 
info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[tc.startIdx], info.Start.Key().(KeyBytes)) + require.Equal(t, ids[tc.endIdx], info.End.Key().(KeyBytes)) + has, err := s.Has(ids[tc.startIdx]) + require.NoError(t, err) + require.True(t, has) + has, err = s.Has(ids[tc.endIdx]) + require.NoError(t, err) + require.True(t, has) + }) + } +} + +func TestDBItemStoreAdd(t *testing.T) { + ids := []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := populateDB(t, 8, ids) + s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + it, err := s.Min() + require.NoError(t, err) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + it.Key().(KeyBytes).String()) + + newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) + require.NoError(t, s.Add(context.Background(), newID)) + + // // QQQQQ: rm + // s.ft.traceEnabled = true + // var sb strings.Builder + // s.ft.dump(&sb) + // t.Logf("tree:\n%s", sb.String()) + + info, err := s.GetRangeInfo(nil, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 3, info.Count) + require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) + require.Equal(t, ids[0], info.End.Key().(KeyBytes)) +} + +func TestDBItemStoreCopy(t *testing.T) { + ids := []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + 
util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := populateDB(t, 8, ids) + s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + it, err := s.Min() + require.NoError(t, err) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + it.Key().(KeyBytes).String()) + + copy := s.Copy() + + info, err := copy.GetRangeInfo(nil, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) + require.Equal(t, ids[0], info.End.Key().(KeyBytes)) + + newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) + require.NoError(t, copy.Add(context.Background(), newID)) + + info, err = s.GetRangeInfo(nil, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) + require.Equal(t, ids[0], info.End.Key().(KeyBytes)) + + info, err = copy.GetRangeInfo(nil, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 3, info.Count) + require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) + require.Equal(t, ids[0], info.End.Key().(KeyBytes)) +} diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 8228902b1c..17f165c105 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -389,16 +389,17 @@ func (ft *fpTree) clone() *fpTree { np: ft.np, idStore: ft.idStore.clone(), root: ft.root, + keyLen: ft.keyLen, maxDepth: ft.maxDepth, } } func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: fpA %s fpB %s p %s\n", fpA.(fmt.Stringer), 
fpB.(fmt.Stringer), p) + // ft.log("QQQQQ: pushDown: fpA %s fpB %s p %s", fpA, fpB, p) fpCombined := fpA fpCombined.update(fpB[:]) if ft.maxDepth != 0 && p.len() == ft.maxDepth { - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add at maxDepth\n") + // ft.log("QQQQQ: pushDown: add at maxDepth") return ft.np.add(fpCombined, curCount+1, noIndex, noIndex) } if curCount != 1 { @@ -406,16 +407,16 @@ func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) node } dirA := fpA.bitFromLeft(p.len()) dirB := fpB.bitFromLeft(p.len()) - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: bitFromLeft %d: dirA %v dirB %v\n", p.len(), dirA, dirB) + // ft.log("QQQQQ: pushDown: bitFromLeft %d: dirA %v dirB %v", p.len(), dirA, dirB) if dirA == dirB { childIdx := ft.pushDown(fpA, fpB, p.dir(dirA), 1) if dirA { r := ft.np.add(fpCombined, 2, noIndex, childIdx) - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: left => %d\n", r) + // ft.log("QQQQQ: pushDown: sameDir: left => %d", r) return r } else { r := ft.np.add(fpCombined, 2, childIdx, noIndex) - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: sameDir: right => %d\n", r) + // ft.log("QQQQQ: pushDown: sameDir: right => %d", r) return r } } @@ -424,11 +425,11 @@ func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) node idxB := ft.np.add(fpB, curCount, noIndex, noIndex) if dirA { r := ft.np.add(fpCombined, 2, idxB, idxA) - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add A-B => %d\n", r) + // ft.log("QQQQQ: pushDown: add A-B => %d", r) return r } else { r := ft.np.add(fpCombined, 2, idxA, idxB) - // fmt.Fprintf(os.Stderr, "QQQQQ: pushDown: add B-A => %d\n", r) + // ft.log("QQQQQ: pushDown: add B-A => %d", r) return r } } @@ -436,7 +437,7 @@ func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) node func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { if idx == noIndex { r := ft.np.add(fp, 1, noIndex, noIndex) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: 
addNew fp %s p %s => %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), r) + // ft.log("QQQQQ: addValue: addNew fp %s p %s => %d", fp, p, r) return r } node := ft.np.node(idx) @@ -445,45 +446,52 @@ func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { // we're at a leaf node, need to push down the old fingerprint, or, // if we've reached the max depth, just update the current node r := ft.pushDown(fp, node.fp, p, node.c) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: pushDown fp %s p %s oldIdx %d => %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx, r) + // ft.log("QQQQQ: addValue: pushDown fp %s p %s oldIdx %d => %d", fp, p, idx, r) return r } fpCombined := fp fpCombined.update(node.fp[:]) if fp.bitFromLeft(p.len()) { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx) + // ft.log("QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d", fp, p, idx) if node.left != noIndex { ft.np.ref(node.left) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: ref left %d -- refCount %d\n", node.left, ft.np.entry(node.left).refCount) + // ft.log("QQQQQ: addValue: ref left %d -- refCount %d", node.left, ft.np.entry(node.left).refCount) } newRight := ft.addValue(fp, p.right(), node.right) r := ft.np.add(fpCombined, node.c+1, node.left, newRight) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d => %d node.left %d newRight %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx, r, node.left, newRight) + // ft.log("QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d => %d node.left %d newRight %d", fp, p, idx, r, node.left, newRight) return r } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx) + // ft.log("QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d", fp, p, idx) if node.right != noIndex { - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: ref right %d -- refCount %d\n", node.right, 
ft.np.entry(node.right).refCount) + // ft.log("QQQQQ: addValue: ref right %d -- refCount %d", node.right, ft.np.entry(node.right).refCount) ft.np.ref(node.right) } newLeft := ft.addValue(fp, p.left(), node.left) r := ft.np.add(fpCombined, node.c+1, newLeft, node.right) - // fmt.Fprintf(os.Stderr, "QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d => %d newLeft %d node.right %d\n", fp.(fmt.Stringer), p.(fmt.Stringer), idx, r, newLeft, node.right) + // ft.log("QQQQQ: addValue: replaceLeft fp %s p %s oldIdx %d => %d newLeft %d node.right %d", fp, p, idx, r, newLeft, node.right) return r } } -func (ft *fpTree) addHash(h KeyBytes) error { - // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: %s\n", hex.EncodeToString(h)) +func (ft *fpTree) addStoredHash(h KeyBytes) { var fp fingerprint fp.update(h) ft.rootMtx.Lock() defer ft.rootMtx.Unlock() + ft.log("addStoredHash: h %s fp %s", h, fp) oldRoot := ft.root ft.root = ft.addValue(fp, 0, ft.root) ft.releaseNode(oldRoot) - // fmt.Fprintf(os.Stderr, "QQQQQ: addHash: new root %d\n", ft.root) - return ft.idStore.registerHash(h) +} + +func (ft *fpTree) addHash(h KeyBytes) error { + ft.log("addHash: h %s", h) + if err := ft.idStore.registerHash(h); err != nil { + return err + } + ft.addStoredHash(h) + return nil } func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeIndex, rp prefix, found bool) { diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 026922966c..a5addfb9bf 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -221,7 +221,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx, yIdx int x, y string limit int - fp fingerprint + fp string count uint32 itype int startIdx, endIdx int @@ -242,7 +242,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { x: "123456789abcdef0000000000000000000000000000000000000000000000000", y: "123456789abcdef0000000000000000000000000000000000000000000000000", limit: -1, - fp: 
hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: 0, startIdx: -1, @@ -252,7 +252,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { x: "123456789abcdef0000000000000000000000000000000000000000000000000", y: "123456789abcdef0000000000000000000000000000000000000000000000000", limit: 1, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: 0, startIdx: -1, @@ -262,7 +262,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { x: "123456789abcdef0000000000000000000000000000000000000000000000000", y: "223456789abcdef0000000000000000000000000000000000000000000000000", limit: 1, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: -1, startIdx: -1, @@ -272,7 +272,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { x: "223456789abcdef0000000000000000000000000000000000000000000000000", y: "123456789abcdef0000000000000000000000000000000000000000000000000", limit: 1, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: 1, startIdx: -1, @@ -292,10 +292,11 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { }, ranges: []rangeTestCase{ { - xIdx: 0, - yIdx: 0, - limit: -1, - fp: hexToFingerprint("642464b773377bbddddddddd"), + xIdx: 0, + yIdx: 0, + limit: -1, + // QQQQQ: use string instead of fingerprint in tcs + fp: "642464b773377bbddddddddd", count: 5, itype: 0, startIdx: 0, @@ -305,7 +306,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 0, limit: 0, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: 0, startIdx: 0, @@ -315,7 +316,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 0, limit: 3, - fp: hexToFingerprint("4761032dcfe98ba555555555"), + fp: "4761032dcfe98ba555555555", count: 3, itype: 0, startIdx: 0, @@ -325,7 +326,7 @@ func testFPTree(t 
*testing.T, makeIDStore idStoreFunc) { xIdx: 4, yIdx: 4, limit: -1, - fp: hexToFingerprint("642464b773377bbddddddddd"), + fp: "642464b773377bbddddddddd", count: 5, itype: 0, startIdx: 4, @@ -335,7 +336,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 4, yIdx: 4, limit: 1, - fp: hexToFingerprint("abcdef123456789000000000"), + fp: "abcdef123456789000000000", count: 1, itype: 0, startIdx: 4, @@ -345,7 +346,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 1, limit: -1, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 1, itype: -1, startIdx: 0, @@ -355,7 +356,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 3, limit: -1, - fp: hexToFingerprint("4761032dcfe98ba555555555"), + fp: "4761032dcfe98ba555555555", count: 3, itype: -1, startIdx: 0, @@ -365,7 +366,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 4, limit: 3, - fp: hexToFingerprint("4761032dcfe98ba555555555"), + fp: "4761032dcfe98ba555555555", count: 3, itype: -1, startIdx: 0, @@ -375,7 +376,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 4, limit: 0, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: -1, startIdx: 0, @@ -385,7 +386,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 1, yIdx: 4, limit: -1, - fp: hexToFingerprint("cfe98ba54761032ddddddddd"), + fp: "cfe98ba54761032ddddddddd", count: 3, itype: -1, startIdx: 1, @@ -395,7 +396,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 1, yIdx: 0, limit: -1, - fp: hexToFingerprint("642464b773377bbddddddddd"), + fp: "642464b773377bbddddddddd", count: 4, itype: 1, startIdx: 1, @@ -405,7 +406,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 2, yIdx: 0, limit: -1, - fp: hexToFingerprint("761032cfe98ba54ddddddddd"), + fp: "761032cfe98ba54ddddddddd", count: 3, itype: 1, startIdx: 2, @@ -415,7 
+416,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 2, yIdx: 0, limit: 0, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: 1, startIdx: 2, @@ -425,7 +426,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 3, yIdx: 1, limit: -1, - fp: hexToFingerprint("2345679abcdef01888888888"), + fp: "2345679abcdef01888888888", count: 3, itype: 1, startIdx: 3, @@ -435,7 +436,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 3, yIdx: 2, limit: -1, - fp: hexToFingerprint("317131e226622ee888888888"), + fp: "317131e226622ee888888888", count: 4, itype: 1, startIdx: 3, @@ -445,7 +446,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 3, yIdx: 2, limit: 3, - fp: hexToFingerprint("2345679abcdef01888888888"), + fp: "2345679abcdef01888888888", count: 3, itype: 1, startIdx: 3, @@ -455,7 +456,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { x: "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0", y: "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", limit: -1, - fp: hexToFingerprint("000000000000000000000000"), + fp: "000000000000000000000000", count: 0, itype: -1, startIdx: 0, @@ -477,7 +478,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 0, limit: -1, - fp: hexToFingerprint("a76fc452775b55e0dacd8be5"), + fp: "a76fc452775b55e0dacd8be5", count: 4, itype: 0, startIdx: 0, @@ -487,7 +488,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 0, limit: 3, - fp: hexToFingerprint("4e5ea7ab7f38576018653418"), + fp: "4e5ea7ab7f38576018653418", count: 3, itype: 0, startIdx: 0, @@ -497,7 +498,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 0, yIdx: 3, limit: -1, - fp: hexToFingerprint("4e5ea7ab7f38576018653418"), + fp: "4e5ea7ab7f38576018653418", count: 3, itype: -1, startIdx: 0, @@ -507,7 +508,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 3, 
yIdx: 1, limit: -1, - fp: hexToFingerprint("87760f5e21a0868dc3b0c7a9"), + fp: "87760f5e21a0868dc3b0c7a9", count: 2, itype: 1, startIdx: 3, @@ -517,7 +518,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 3, yIdx: 2, limit: -1, - fp: hexToFingerprint("05ef78ea6568c6000e6cd5b9"), + fp: "05ef78ea6568c6000e6cd5b9", count: 3, itype: 1, startIdx: 3, @@ -567,7 +568,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 31, yIdx: 0, limit: -1, - fp: hexToFingerprint("e9110a384198b47be2bb63e6"), + fp: "e9110a384198b47be2bb63e6", count: 1, itype: 1, startIdx: 31, @@ -617,7 +618,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { x: "582485793d71c3e8429b9b2c8df360c2ea7bf90080d5bf375fe4618b00f59c0b", y: "7eff517d2f11ed32f935be3001499ac779160a4891a496f88da0ceb33e3496cc", limit: -1, - fp: hexToFingerprint("66883aa35d2c8d293f07c5c5"), + fp: "66883aa35d2c8d293f07c5c5", count: 1, itype: -1, startIdx: 10, @@ -639,7 +640,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 2, yIdx: 0, limit: 1, - fp: hexToFingerprint("b5527010e990254702f77ffc"), + fp: "b5527010e990254702f77ffc", count: 1, itype: 1, startIdx: 2, @@ -661,7 +662,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { xIdx: 3, yIdx: 3, limit: 2, - fp: hexToFingerprint("9fbedabb68b3dd688767f8e9"), + fp: "9fbedabb68b3dd688767f8e9", count: 2, itype: 0, startIdx: 3, @@ -684,7 +685,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { y: "cb93566c2037bc8353162e9988974e4585c14f656bf6aed8fa51d00e1ae594de", limit: -1, // fingerprint: 0xb5, 0xc0, 0x6e, 0x5b, 0x55, 0x30, 0x61, 0xbf, 0xa1, 0xc7, 0xe, 0x75 - fp: hexToFingerprint("b5c06e5b553061bfa1c70e75"), + fp: "b5c06e5b553061bfa1c70e75", count: 3, itype: -1, startIdx: 1, @@ -727,7 +728,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { t.Run(name, func(t *testing.T) { fpr, err := ft.fingerprintInterval(x[:], y[:], rtc.limit) require.NoError(t, err) - require.Equal(t, rtc.fp, fpr.fp, "fp") + 
require.Equal(t, rtc.fp, fpr.fp.String(), "fp") require.Equal(t, rtc.count, fpr.count, "count") require.Equal(t, rtc.itype, fpr.itype, "itype") @@ -784,7 +785,7 @@ func TestFPTreeClone(t *testing.T) { fpr, err := ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) require.NoError(t, err) - require.Equal(t, hexToFingerprint("222222222222222222222222"), fpr.fp, "fp") + require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") @@ -805,7 +806,7 @@ func TestFPTreeClone(t *testing.T) { // original tree unchanged --- rmme!!!! fpr, err = ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) require.NoError(t, err) - require.Equal(t, hexToFingerprint("222222222222222222222222"), fpr.fp, "fp") + require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") @@ -813,14 +814,14 @@ func TestFPTreeClone(t *testing.T) { fpr, err = ft2.fingerprintInterval(hashes[0][:], hashes[0][:], -1) require.NoError(t, err) - require.Equal(t, hexToFingerprint("666666666666666666666666"), fpr.fp, "fp") + require.Equal(t, "666666666666666666666666", fpr.fp.String(), "fp") require.Equal(t, uint32(3), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") // original tree unchanged fpr, err = ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) require.NoError(t, err) - require.Equal(t, hexToFingerprint("222222222222222222222222"), fpr.fp, "fp") + require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") @@ -1201,20 +1202,21 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { var np nodePool if *hs == nil { t.Logf("loading IDs") - _, err = db.Exec("select id from atxs where epoch = 26 order by id", nil, func(stmt *sql.Statement) bool { - var id types.Hash32 - stmt.ColumnBytes(0, id[:]) 
- *hs = append(*hs, id) - // v := load64(id[:]) - // counts[v>>40]++ - // if first { - // first = false - // } else { - // prefLens[bits.LeadingZeros64(prev^v)]++ - // } - // prev = v - return true - }) + _, err = db.Exec("select id from atxs where epoch = 26 order by id", + nil, func(stmt *sql.Statement) bool { + var id types.Hash32 + stmt.ColumnBytes(0, id[:]) + *hs = append(*hs, id) + // v := load64(id[:]) + // counts[v>>40]++ + // if first { + // first = false + // } else { + // prefLens[bits.LeadingZeros64(prev^v)]++ + // } + // prev = v + return true + }) require.NoError(t, err) } @@ -1309,7 +1311,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { } func TestATXFP(t *testing.T) { - // t.Skip("slow test") + t.Skip("slow test") var hs []types.Hash32 for maxDepth := 15; maxDepth <= 23; maxDepth++ { for i := 0; i < 3; i++ { From fd092293bb9020a58be27e3754f46a9375a29792 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 25 Jul 2024 17:42:17 +0400 Subject: [PATCH 52/76] dbsync: integrate with hashsync pairwise sync --- sync2/dbsync/dbitemstore.go | 80 ++++++++++ sync2/dbsync/dbitemstore_test.go | 8 +- sync2/dbsync/dbiter.go | 28 +++- sync2/dbsync/dbiter_test.go | 40 ++++- sync2/dbsync/fptree.go | 6 +- sync2/dbsync/fptree_test.go | 91 ++++++------ sync2/dbsync/p2p_test.go | 242 +++++++++++++++++++++++++++++++ sync2/dbsync/sqlidstore.go | 4 +- sync2/hashsync/log.go | 55 +++++++ sync2/hashsync/rangesync.go | 62 +++++++- 10 files changed, 549 insertions(+), 67 deletions(-) create mode 100644 sync2/dbsync/p2p_test.go create mode 100644 sync2/hashsync/log.go diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 24655d265e..a9fa4972b8 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -5,6 +5,7 @@ import ( "errors" "sync" + "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) @@ -133,3 +134,82 @@ func (d 
*DBItemStore) Has(k hashsync.Ordered) (bool, error) { } return false, nil } + +// TODO: get rid of ItemStoreAdapter, it shouldn't be needed +type ItemStoreAdapter struct { + s *DBItemStore +} + +var _ hashsync.ItemStore = &ItemStoreAdapter{} + +func NewItemStoreAdapter(s *DBItemStore) *ItemStoreAdapter { + return &ItemStoreAdapter{s: s} +} + +func (a *ItemStoreAdapter) wrapIterator(it hashsync.Iterator) hashsync.Iterator { + if it == nil { + return nil + } + return &iteratorAdapter{it: it} +} + +// Add implements hashsync.ItemStore. +func (a *ItemStoreAdapter) Add(ctx context.Context, k hashsync.Ordered) error { + h := k.(types.Hash32) + return a.s.Add(ctx, KeyBytes(h[:])) +} + +// Copy implements hashsync.ItemStore. +func (a *ItemStoreAdapter) Copy() hashsync.ItemStore { + return &ItemStoreAdapter{s: a.s.Copy().(*DBItemStore)} +} + +// GetRangeInfo implements hashsync.ItemStore. +func (a *ItemStoreAdapter) GetRangeInfo(preceding hashsync.Iterator, x hashsync.Ordered, y hashsync.Ordered, count int) (hashsync.RangeInfo, error) { + hx := x.(types.Hash32) + hy := y.(types.Hash32) + info, err := a.s.GetRangeInfo(preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) + if err != nil { + return hashsync.RangeInfo{}, err + } + var fp types.Hash12 + src := info.Fingerprint.(fingerprint) + copy(fp[:], src[:]) + return hashsync.RangeInfo{ + Fingerprint: fp, + Count: info.Count, + Start: a.wrapIterator(info.Start), + End: a.wrapIterator(info.End), + }, nil +} + +// Has implements hashsync.ItemStore. +func (a *ItemStoreAdapter) Has(k hashsync.Ordered) (bool, error) { + h := k.(types.Hash32) + return a.s.Has(KeyBytes(h[:])) +} + +// Min implements hashsync.ItemStore. 
+func (a *ItemStoreAdapter) Min() (hashsync.Iterator, error) { + it, err := a.s.Min() + if err != nil { + return nil, err + } + return a.wrapIterator(it), nil +} + +type iteratorAdapter struct { + it hashsync.Iterator +} + +var _ hashsync.Iterator = &iteratorAdapter{} + +func (ia *iteratorAdapter) Key() hashsync.Ordered { + var h types.Hash32 + copy(h[:], ia.it.Key().(KeyBytes)) + return h +} + +func (ia *iteratorAdapter) Next() error { + return ia.it.Next() +} diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go index 9d6b1b7f33..aae101866b 100644 --- a/sync2/dbsync/dbitemstore_test.go +++ b/sync2/dbsync/dbitemstore_test.go @@ -10,7 +10,7 @@ import ( ) func TestDBItemStoreEmpty(t *testing.T) { - db := populateDB(t, 8, nil) + db := populateDB(t, 32, nil) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) it, err := s.Min() require.NoError(t, err) @@ -45,7 +45,7 @@ func TestDBItemStore(t *testing.T) { util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), } - db := populateDB(t, 8, ids) + db := populateDB(t, 32, ids) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) it, err := s.Min() require.NoError(t, err) @@ -134,7 +134,7 @@ func TestDBItemStoreAdd(t *testing.T) { util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), } - db := populateDB(t, 8, ids) + db := populateDB(t, 32, ids) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) it, err := s.Min() require.NoError(t, err) @@ -165,7 +165,7 @@ func TestDBItemStoreCopy(t *testing.T) { util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), } - db := populateDB(t, 8, ids) + db := populateDB(t, 32, ids) s := NewDBItemStore(db, 
"select id from foo", testQuery, 32, 24) it, err := s.Min() require.NoError(t, err) diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index d7fcff6f85..f6b4e2f391 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -208,8 +208,28 @@ type combinedIterator struct { // combineIterators combines multiple iterators into one, returning the smallest current // key among all iterators at each step. -func combineIterators(iters ...iterator) iterator { - return &combinedIterator{iters: iters} +func combineIterators(startingPoint hashsync.Ordered, iters ...iterator) iterator { + var c combinedIterator + // Some of the iterators may already be wrapped around. + // This corresponds to the case when we ask an idStore for iterator + // with a starting point beyond the last key in the store. + if startingPoint == nil { + c.iters = iters + } else { + for _, it := range iters { + if it.Key().Compare(startingPoint) < 0 { + c.wrapped = append(c.wrapped, it) + } else { + c.iters = append(c.iters, it) + } + } + if len(c.iters) == 0 { + // all iterators wrapped around + c.iters = c.wrapped + c.wrapped = nil + } + } + return &c } func (c *combinedIterator) aheadIterator() iterator { @@ -236,9 +256,7 @@ func (c *combinedIterator) aheadIterator() iterator { } func (c *combinedIterator) Key() hashsync.Ordered { - // return c.aheadIterator().Key() - it := c.aheadIterator() - return it.Key() + return c.aheadIterator().Key() } func (c *combinedIterator) Next() error { diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index e1dd3beeff..eeac4c779d 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -45,7 +45,7 @@ func createDB(t *testing.T, keyLen int) sql.Database { t.Cleanup(func() { require.NoError(t, db.Close()) }) - _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d))", keyLen), nil, nil) + _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) require.NoError(t, 
err) return db } @@ -379,11 +379,11 @@ func TestCombineIterators(t *testing.T) { }, } - it := combineIterators(it1, it2) + it := combineIterators(nil, it1, it2) clonedIt := it.clone() for range 3 { var collected []KeyBytes - for i := 0; i < 4; i++ { + for range 4 { k := it.Key() collected = append(collected, k.(KeyBytes)) require.Equal(t, k, clonedIt.Key()) @@ -403,16 +403,46 @@ func TestCombineIterators(t *testing.T) { it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 1}}} - it = combineIterators(it1, it2) + it = combineIterators(nil, it1, it2) require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) require.Error(t, it.Next()) it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}}} it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} - it = combineIterators(it1, it2) + it = combineIterators(nil, it1, it2) require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) require.NoError(t, it.Next()) require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) require.Error(t, it.Next()) } + +func TestCombineIteratorsInitiallyWrapped(t *testing.T) { + it1 := &fakeIterator{ + allItems: []KeyBytes{ + {0x00, 0x00, 0x00, 0x01}, + {0x0a, 0x05, 0x00, 0x00}, + }, + } + it2 := &fakeIterator{ + allItems: []KeyBytes{ + {0x00, 0x00, 0x00, 0x03}, + {0xff, 0x00, 0x00, 0x55}, + }, + } + require.NoError(t, it2.Next()) + it := combineIterators(KeyBytes{0xff, 0x00, 0x00, 0x55}, it1, it2) + var collected []KeyBytes + for range 4 { + k := it.Key() + collected = append(collected, k.(KeyBytes)) + require.NoError(t, it.Next()) + } + require.Equal(t, []KeyBytes{ + {0xff, 0x00, 0x00, 0x55}, + {0x00, 0x00, 0x00, 0x01}, + {0x00, 0x00, 0x00, 0x03}, + {0x0a, 0x05, 0x00, 0x00}, + }, collected) + require.Equal(t, KeyBytes{0xff, 0x00, 0x00, 0x55}, it.Key()) +} diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 17f165c105..15e55f568a 100644 --- a/sync2/dbsync/fptree.go +++ 
b/sync2/dbsync/fptree.go @@ -780,7 +780,9 @@ func (ft *fpTree) aggregateSimple(ac *aggContext) error { panic("BUG: bad followedPrefix") } ft.aggregateLeft(lca.left, load64(ac.x)<<(p.len()+1), p.left(), ac) - ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac) + if ac.limit != 0 { + ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac) + } case lcaIdx == noIndex || !lca.leaf(): ft.log("commonPrefix %s NOT found b/c no items have it", p) default: @@ -800,7 +802,7 @@ func (ft *fpTree) aggregateInverse(ac *aggContext) error { if idx0 != noIndex { pf0Node = ft.np.node(idx0) } - ft.log("pf0 %s idx0 %d found %v", pf0, idx0, found) + ft.log("pf0 %s idx0 %d found %v followedPrefix %s", pf0, idx0, found, followedPrefix) switch { case found && !pf0Node.leaf(): if followedPrefix != pf0 { diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index a5addfb9bf..286628f661 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -701,7 +701,6 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { // ft.traceEnabled = true var hs []types.Hash32 for _, hex := range tc.ids { - t.Logf("add: %s", hex) h := types.HexToHash32(hex) hs = append(hs, h) ft.addHash(h[:]) @@ -998,6 +997,48 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { return fpr } +func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, limit int) fpResult { + expFPR := dumbFP(hs, x, y, limit) + fpr, err := ft.fingerprintInterval(x[:], y[:], limit) + require.NoError(t, err) + require.Equal(t, expFPR, toFPResultWithBounds(fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), limit) + + // QQQQQ: rm + if !reflect.DeepEqual(toFPResultWithBounds(fpr), expFPR) { + t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) + for _, h := range hs { + t.Logf("QQQQQ: hash: %s", h.String()) + } + var sb strings.Builder + ft.dump(&sb) + t.Logf("QQQQQ: tree:\n%s", sb.String()) + } + // QQQQQ: /rm + + 
require.Equal(t, expFPR, toFPResultWithBounds(fpr), + "x=%s y=%s limit=%d", x.String(), y.String(), limit) + + return fpr +} + +func verifySubIntervals(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, limit, d int) fpResult { + fpr := verifyInterval(t, hs, ft, x, y, limit) + // t.Logf("verifySubIntervals: x=%s y=%s limit=%d => count %d", x.String(), y.String(), limit, fpr.count) + if fpr.count > 1 { + c := int((fpr.count + 1) / 2) + if limit >= 0 { + require.Less(t, c, limit) + } + part := verifyInterval(t, hs, ft, x, y, c) + var m types.Hash32 + copy(m[:], part.end.Key().(KeyBytes)) + verifySubIntervals(t, hs, ft, x, m, -1, d+1) + verifySubIntervals(t, hs, ft, m, y, -1, d+1) + } + return fpr +} + func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, maxDepth, repeat int) { var np nodePool ft := newFPTree(&np, idStore, 32, maxDepth) @@ -1031,55 +1072,15 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, x = hs[rand.Intn(numItems)] y = hs[rand.Intn(numItems)] } - expFPR := dumbFP(hs, x, y, -1) - fpr, err := ft.fingerprintInterval(x[:], y[:], -1) - require.NoError(t, err) - - // QQQQQ: rm - if !reflect.DeepEqual(toFPResultWithBounds(fpr), expFPR) { - t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) - for _, h := range hs { - t.Logf("QQQQQ: hash: %s", h.String()) - } - var sb strings.Builder - ft.dump(&sb) - t.Logf("QQQQQ: tree:\n%s", sb.String()) - } - // QQQQQ: /rm - - require.Equal(t, expFPR, toFPResultWithBounds(fpr), - "x=%s y=%s", x.String(), y.String()) - - limit := 0 - if fpr.count != 0 { - limit = rand.Intn(int(fpr.count)) - } - expFPR = dumbFP(hs, x, y, limit) - fpr, err = ft.fingerprintInterval(x[:], y[:], limit) - require.NoError(t, err) - - // QQQQQ: rm - if !reflect.DeepEqual(toFPResultWithBounds(fpr), expFPR) { - t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) - for _, h := range hs { - t.Logf("QQQQQ: hash: %s", h.String()) - } - var sb strings.Builder - ft.dump(&sb) - 
t.Logf("QQQQQ: tree:\n%s", sb.String()) - } - // QQQQQ: /rm - - require.Equal(t, expFPR, toFPResultWithBounds(fpr), - "x=%s y=%s limit=%d", x.String(), y.String(), limit) + verifySubIntervals(t, hs, ft, x, y, -1, 0) } } func TestFPTreeManyItems(t *testing.T) { const ( - repeatOuter = 30 - repeatInner = 20 - numItems = 1 << 13 + repeatOuter = 3 + repeatInner = 5 + numItems = 1 << 10 maxDepth = 12 // numItems = 1 << 5 // maxDepth = 4 diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go new file mode 100644 index 0000000000..b258735d94 --- /dev/null +++ b/sync2/dbsync/p2p_test.go @@ -0,0 +1,242 @@ +package dbsync + +import ( + "context" + "errors" + "io" + "slices" + "testing" + "time" + + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/common/util" + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" +) + +func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { + log := zaptest.NewLogger(t) + dbA := populateDB(t, 32, itemsA) + dbB := populateDB(t, 32, itemsB) + mesh, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + proto := "itest" + storeA := NewItemStoreAdapter(NewDBItemStore(dbA, "select id from foo", testQuery, 32, 24)) + storeB := NewItemStoreAdapter(NewDBItemStore(dbB, "select id from foo", testQuery, 32, 24)) + + // QQQQQ: rmme + // storeB.s.ft.traceEnabled = true + // storeB.qqqqq = true + // require.NoError(t, storeB.s.EnsureLoaded()) + // var sb strings.Builder + // storeA.s.ft.dump(&sb) + // t.Logf("storeA:\n%s", sb.String()) + // sb = strings.Builder{} + // storeB.s.ft.dump(&sb) + // t.Logf("storeB:\n%s", sb.String()) + + srvPeerID := mesh.Hosts()[0].ID() + srv := server.New(mesh.Hosts()[0], proto, + func(ctx context.Context, req 
[]byte, stream io.ReadWriter) error { + pss := hashsync.NewPairwiseStoreSyncer(nil, []hashsync.RangeSetReconcilerOption{ + hashsync.WithMaxSendRange(1), + // uncomment to enable verbose logging which may slow down tests + // hashsync.WithRangeSyncLogger(log.Named("sideA")), + }) + return pss.Serve(ctx, req, stream, storeA) + }, + server.WithTimeout(10*time.Second), + server.WithLog(log)) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + var eg errgroup.Group + + client := server.New(mesh.Hosts()[1], proto, + func(ctx context.Context, req []byte, stream io.ReadWriter) error { + return errors.New("client should not receive requests") + }, + server.WithTimeout(10*time.Second), + server.WithLog(log)) + + defer func() { + cancel() + eg.Wait() + }() + eg.Go(func() error { + return srv.Run(ctx) + }) + eg.Go(func() error { + // TBD: this probably isn't needed + return client.Run(ctx) + }) + + require.Eventually(t, func() bool { + for _, h := range mesh.Hosts() { + if len(h.Mux().Protocols()) == 0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + + pss := hashsync.NewPairwiseStoreSyncer(client, []hashsync.RangeSetReconcilerOption{ + hashsync.WithMaxSendRange(1), + // uncomment to enable verbose logging which may slow down tests + // hashsync.WithRangeSyncLogger(log.Named("sideB")), + }) + require.NoError(t, pss.SyncStore(ctx, srvPeerID, storeB, nil, nil)) + + // // QQQQQ: rmme + // sb = strings.Builder{} + // storeA.s.ft.dump(&sb) + // t.Logf("storeA post-sync:\n%s", sb.String()) + // sb = strings.Builder{} + // storeB.s.ft.dump(&sb) + // t.Logf("storeB post-sync:\n%s", sb.String()) + + if len(combinedItems) == 0 { + return + } + it, err := storeA.Min() + require.NoError(t, err) + var actItemsA []KeyBytes + if len(combinedItems) == 0 { + assert.Nil(t, it) + } else { + for range combinedItems { + // t.Logf("synced itemA: %s", it.Key().(types.Hash32).String()) + h := it.Key().(types.Hash32) + actItemsA = 
append(actItemsA, h[:]) + require.NoError(t, it.Next()) + } + h := it.Key().(types.Hash32) + assert.Equal(t, actItemsA[0], KeyBytes(h[:])) + } + + it, err = storeB.Min() + require.NoError(t, err) + var actItemsB []KeyBytes + if len(combinedItems) == 0 { + assert.Nil(t, it) + } else { + for range combinedItems { + // t.Logf("synced itemB: %s", it.Key().(types.Hash32).String()) + h := it.Key().(types.Hash32) + actItemsB = append(actItemsB, h[:]) + require.NoError(t, it.Next()) + } + h := it.Key().(types.Hash32) + assert.Equal(t, actItemsB[0], KeyBytes(h[:])) + assert.Equal(t, combinedItems, actItemsA) + assert.Equal(t, actItemsA, actItemsB) + } +} + +func TestP2P(t *testing.T) { + t.Run("predefined items", func(t *testing.T) { + verifyP2P( + t, []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), + }, + []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + }, + []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), + }) + }) + t.Run("predefined items 2", func(t *testing.T) { + verifyP2P( + t, []KeyBytes{ + 
util.FromHex("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), + util.FromHex("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + util.FromHex("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), + util.FromHex("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + }, + []KeyBytes{ + util.FromHex("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + util.FromHex("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), + util.FromHex("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + util.FromHex("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + }, + []KeyBytes{ + util.FromHex("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), + util.FromHex("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + util.FromHex("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), + util.FromHex("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), + util.FromHex("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + util.FromHex("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + }) + }) + t.Run("empty to non-empty", func(t *testing.T) { + verifyP2P( + t, nil, + []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + }, + []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + 
util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + }) + }) + t.Run("empty to empty", func(t *testing.T) { + verifyP2P(t, nil, nil, nil) + }) + t.Run("random test", func(t *testing.T) { + // TODO: increase these values and profile + const nShared = 8000 + const nUniqueA = 400 + const nUniqueB = 800 + // const nShared = 2 + // const nUniqueA = 2 + // const nUniqueB = 2 + combined := make([]KeyBytes, 0, nShared+nUniqueA+nUniqueB) + itemsA := make([]KeyBytes, nShared+nUniqueA) + for i := range itemsA { + h := types.RandomHash() + itemsA[i] = KeyBytes(h[:]) + combined = append(combined, itemsA[i]) + // t.Logf("itemsA[%d] = %s", i, itemsA[i]) + } + itemsB := make([]KeyBytes, nShared+nUniqueB) + for i := range itemsB { + if i < nShared { + itemsB[i] = slices.Clone(itemsA[i]) + } else { + h := types.RandomHash() + itemsB[i] = KeyBytes(h[:]) + combined = append(combined, itemsB[i]) + } + // t.Logf("itemsB[%d] = %s", i, itemsB[i]) + } + slices.SortFunc(combined, func(a, b KeyBytes) int { + return a.Compare(b) + }) + // for i, v := range combined { + // t.Logf("combined[%d] = %s", i, v) + // } + verifyP2P(t, itemsA, itemsB, combined) + // TODO: multiple iterations + }) +} diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index d7bda85b14..8168f48919 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -77,7 +77,7 @@ func (s *dbBackedStore) start() (iterator, error) { } memIt, err := s.inMemIDStore.start() if err == nil { - return combineIterators(dbIt, memIt), nil + return combineIterators(nil, dbIt, memIt), nil } else if errors.Is(err, errEmptySet) { return dbIt, nil } @@ -94,7 +94,7 @@ func (s *dbBackedStore) iter(from KeyBytes) (iterator, error) { } memIt, err := s.inMemIDStore.iter(from) if err == nil { - return combineIterators(dbIt, memIt), nil + return combineIterators(from, dbIt, memIt), nil } else if errors.Is(err, errEmptySet) { return dbIt, nil } diff --git a/sync2/hashsync/log.go 
b/sync2/hashsync/log.go new file mode 100644 index 0000000000..863f273ae2 --- /dev/null +++ b/sync2/hashsync/log.go @@ -0,0 +1,55 @@ +package hashsync + +import ( + "encoding/hex" + "reflect" + + "go.uber.org/zap" + + "github.com/spacemeshos/go-spacemesh/common/types" +) + +func IteratorField(name string, it Iterator) zap.Field { + if it == nil { + return zap.String(name, "") + } + return HexField(name, it.Key()) +} + +// based on code from testify +func isNil(object any) bool { + if object == nil { + return true + } + + value := reflect.ValueOf(object) + switch value.Kind() { + case + reflect.Chan, reflect.Func, + reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Slice, reflect.UnsafePointer: + + return value.IsNil() + } + + return false +} + +func HexField(name string, k any) zap.Field { + switch h := k.(type) { + case types.Hash32: + return zap.String(name, h.ShortString()) + case types.Hash12: + return zap.String(name, hex.EncodeToString(h[:5])) + case []byte: + if len(h) > 5 { + h = h[:5] + } + return zap.String(name, hex.EncodeToString(h[:5])) + default: + if isNil(k) { + return zap.String(name, "") + } + panic("unexpected type") + } +} diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 7dfce5494c..77bae1608d 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -8,6 +8,8 @@ import ( "reflect" "slices" "strings" + + "go.uber.org/zap" ) const ( @@ -64,19 +66,19 @@ func SyncMessageToString(m SyncMessage) string { sb.WriteString("<" + m.Type().String()) if x := m.X(); x != nil { - sb.WriteString(" X=" + x.(fmt.Stringer).String()) + sb.WriteString(" X=" + x.(fmt.Stringer).String()[:10]) } if y := m.Y(); y != nil { - sb.WriteString(" Y=" + y.(fmt.Stringer).String()) + sb.WriteString(" Y=" + y.(fmt.Stringer).String()[:10]) } if count := m.Count(); count != 0 { fmt.Fprintf(&sb, " Count=%d", count) } if fp := m.Fingerprint(); fp != nil { - sb.WriteString(" FP=" + fp.(fmt.Stringer).String()) + sb.WriteString(" 
FP=" + fp.(fmt.Stringer).String()[:10]) } for _, k := range m.Keys() { - fmt.Fprintf(&sb, " item=%s", k.(fmt.Stringer).String()) + fmt.Fprintf(&sb, " item=%s", k.(fmt.Stringer).String()[:10]) } sb.WriteString(">") return sb.String() @@ -143,6 +145,14 @@ func WithSampleSize(s int) RangeSetReconcilerOption { } } +// TODO: RangeSetReconciler should sit in a separate package +// and WithRangeSyncLogger should be named WithLogger +func WithRangeSyncLogger(log *zap.Logger) RangeSetReconcilerOption { + return func(r *RangeSetReconciler) { + r.log = log + } +} + type ProbeResult struct { FP any Count int @@ -154,6 +164,7 @@ type RangeSetReconciler struct { maxSendRange int itemChunkSize int sampleSize int + log *zap.Logger } func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *RangeSetReconciler { @@ -162,6 +173,7 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang maxSendRange: DefaultMaxSendRange, itemChunkSize: DefaultItemChunkSize, sampleSize: DefaultSampleSize, + log: zap.NewNop(), } for _, opt := range opts { opt(rsr) @@ -189,6 +201,7 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", // qqqqRmmeK(preceding)) // TODO: don't re-request range info for the first part of range after stop + rsr.log.Debug("processSubrange", IteratorField("preceding", preceding), HexField("x", x), HexField("y", y)) info, err := rsr.is.GetRangeInfo(preceding, x, y, -1) if err != nil { return nil, err @@ -212,12 +225,15 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, case info.Count == 0: // We have no more items in this subrange. // Ask peer to send any items it has in the range + rsr.log.Debug("processSubrange: send empty range", HexField("x", x), HexField("y", y)) if err := c.SendEmptyRange(x, y); err != nil { return nil, err } default: // The range is non-empty and large enough. 
// Send fingerprint so that the peer can further subdivide it. + rsr.log.Debug("processSubrange: send fingerprint", HexField("x", x), HexField("y", y), + zap.Int("count", info.Count)) if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { return nil, err } @@ -227,6 +243,8 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, } func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) (it Iterator, done bool, err error) { + rsr.log.Debug("handleMessage", IteratorField("preceding", preceding), + zap.String("msg", SyncMessageToString(msg))) x := msg.X() y := msg.Y() done = true @@ -244,6 +262,10 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg if err != nil { return nil, false, err } + rsr.log.Debug("handleMessage: send probe response", + HexField("fingerpint", info.Fingerprint), + zap.Int("count", info.Count), + IteratorField("it", it)) if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, 0, it); err != nil { return nil, false, err } @@ -259,6 +281,13 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg if err != nil { return nil, false, err } + rsr.log.Debug("handleMessage: range info", + HexField("x", x), HexField("y", y), + IteratorField("start", info.Start), + IteratorField("end", info.End), + zap.Int("count", info.Count), + HexField("fingerprint", info.Fingerprint)) + // fmt.Fprintf(os.Stderr, "QQQQQ msg %s %#v fp %v start %#v end %#v count %d\n", msg.Type(), msg, info.Fingerprint, info.Start, info.End, info.Count) switch { case msg.Type() == MessageTypeEmptyRange || @@ -271,9 +300,13 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // the range doesn't need any further handling by the peer. 
if info.Count != 0 { done = false + rsr.log.Debug("handleMessage: send items", zap.Int("count", info.Count), + IteratorField("start", info.Start)) if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { return nil, false, err } + } else { + rsr.log.Debug("handleMessage: local range is empty") } case msg.Type() == MessageTypeProbe: sampleSize := msg.Count() @@ -306,6 +339,8 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // empty on our side done = false if info.Count != 0 { + rsr.log.Debug("handleMessage: send small range", + HexField("x", x), HexField("y", y), zap.Int("count", info.Count)) // fmt.Fprintf(os.Stderr, "small incoming range: %s -> SendItems\n", msg) if err := c.SendRangeContents(x, y, info.Count); err != nil { return nil, false, err @@ -314,6 +349,8 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg return nil, false, err } } else { + rsr.log.Debug("handleMessage: empty incoming range", + HexField("x", x), HexField("y", y)) // fmt.Fprintf(os.Stderr, "small incoming range: %s -> empty range msg\n", msg) if err := c.SendEmptyRange(x, y); err != nil { return nil, false, err @@ -323,7 +360,12 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // Need to split the range. 
// Note that there's no special handling for rollover ranges with x >= y // These need to be handled by ItemStore.GetRangeInfo() + // TODO: instead of count-based split, use split between X and Y with + // lower bits set to zero to avoid SQL queries on the edges count := (info.Count + 1) / 2 + rsr.log.Debug("handleMessage: PRE split range", + HexField("x", x), HexField("y", y), + zap.Int("countArg", count)) part, err := rsr.is.GetRangeInfo(preceding, x, y, count) if err != nil { return nil, false, err @@ -331,6 +373,13 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg if part.End == nil { panic("BUG: can't split range with count > 1") } + rsr.log.Debug("handleMessage: split range", + HexField("x", x), HexField("y", y), + zap.Int("countArg", count), + zap.Int("count", part.Count), + HexField("fingerprint", part.Fingerprint), + IteratorField("start", part.Start), + IteratorField("middle", part.End)) middle := part.End.Key() next, err := rsr.processSubrange(c, info.Start, x, middle) if err != nil { @@ -361,7 +410,9 @@ func (rsr *RangeSetReconciler) Initiate(c Conduit) error { } func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { + rsr.log.Debug("inititate", HexField("x", x), HexField("y", y)) if x == nil { + rsr.log.Debug("initiate: send empty set") if err := c.SendEmptySet(); err != nil { return err } @@ -374,6 +425,7 @@ func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { case info.Count == 0: panic("empty full min-min range") case info.Count < rsr.maxSendRange: + rsr.log.Debug("initiate: send whole range", zap.Int("count", info.Count)) if err := c.SendRangeContents(x, y, info.Count); err != nil { return err } @@ -381,6 +433,7 @@ func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { return err } default: + rsr.log.Debug("initiate: send fingerprint", zap.Int("count", info.Count)) if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil 
{ return err } @@ -546,6 +599,7 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo for _, msg := range msgs { if msg.Type() == MessageTypeItemBatch { for _, k := range msg.Keys() { + rsr.log.Debug("Process: add item", HexField("item", k)) if err := rsr.is.Add(ctx, k); err != nil { return false, fmt.Errorf("error adding an item to the store: %w", err) } From dc5b1f5515c0c14388993938dc3b8f11adcf9bc9 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 31 Jul 2024 02:20:15 +0400 Subject: [PATCH 53/76] wip --- sync2/dbsync/dbiter.go | 2 + sync2/dbsync/fptree.go | 51 +++++++++++++++-- sync2/dbsync/fptree_test.go | 109 +++++++++++++++++++++++++++++++++--- sync2/hashsync/log.go | 7 ++- sync2/hashsync/rangesync.go | 24 ++++++-- 5 files changed, 174 insertions(+), 19 deletions(-) diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index f6b4e2f391..59bd629d58 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -92,6 +92,8 @@ func newDBRangeIterator( chunk: make([]KeyBytes, maxChunkSize), singleChunk: false, } + // panic("TBD: QQQQQ: do not preload the iterator! Key should panic upon no entries. With from > max item, iterator should work, wrapping around (TEST)!") + // panic("TBD: QQQQQ: Key() should return an error!") if err := it.load(); err != nil { return nil, err } diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 15e55f568a..8fc2aca0ce 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -302,14 +302,50 @@ type aggContext struct { lastPrefix *prefix } +// prefixAtOrAfterX verifies that the any key with the prefix p is at or after x. +// It can be used for the whole interval in case of a normal interval. +// With inverse intervals, it should only be used for the [x, max) part +// of the interval. func (ac *aggContext) prefixAtOrAfterX(p prefix) bool { return p.bits()<<(64-p.len()) >= load64(ac.x) } +// prefixBelowY verifies that the any key with the prefix p is below y. 
+// It can be used for the whole interval in case of a normal interval. +// With inverse intervals, it should only be used for the [0, y) part +// of the interval. func (ac *aggContext) prefixBelowY(p prefix) bool { + // QQQQQ: TBD: <= must work, check !!!!! return (p.bits()+1)<<(64-p.len())-1 < load64(ac.y) } +func (ac *aggContext) fingreprintAtOrAfterX(fp fingerprint) bool { + k := make(KeyBytes, len(ac.x)) + copy(k, fp[:]) + return bytes.Compare(k, ac.x) >= 0 +} + +func (ac *aggContext) fingreprintBelowY(fp fingerprint) bool { + k := make(KeyBytes, len(ac.x)) + copy(k, fp[:]) + k[:fingerprintBytes].inc() // 1 after max key derived from the fingerprint + return bytes.Compare(k, ac.y) <= 0 +} + +func (ac *aggContext) nodeAtOrAfterX(node node, p prefix) bool { + if node.c == 1 { + return ac.fingreprintAtOrAfterX(node.fp) + } + return ac.prefixAtOrAfterX(p) +} + +func (ac *aggContext) nodeBelowY(node node, p prefix) bool { + if node.c == 1 { + return ac.fingreprintBelowY(node.fp) + } + return ac.prefixBelowY(p) +} + func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { switch { case ac.limit < 0: @@ -688,13 +724,16 @@ func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContex // case ac.limit == 0: // ft.log("stop: limit exhausted") // return false, ft.markEnd(p, ac) - case ac.prefixAtOrAfterX(p) && ac.maybeIncludeNode(node, p): + case ac.nodeAtOrAfterX(node, p) && ac.maybeIncludeNode(node, p): ft.log("including node in full: %s limit %d", p, ac.limit) return ac.limit != 0, nil case p.len() == ft.maxDepth || node.leaf(): if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } + // if node.c == 1 {; + // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateLeft: edge with x %s p %s limit %d\n", ac.x, p, ac.limit) + // } return ft.aggregateEdge(ac.x, nil, p, ac) case v&bit63 == 0: ft.log("incl right node %d + go left to node %d", node.right, node.left) @@ -725,7 +764,7 @@ func (ft *fpTree) 
aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggConte // case ac.limit == 0: // ft.log("stop: limit exhausted") // return false, ft.markEnd(p, ac) - case ac.prefixBelowY(p) && ac.maybeIncludeNode(node, p): + case ac.nodeBelowY(node, p) && ac.maybeIncludeNode(node, p): ft.log("including node in full: %s limit %d", p, ac.limit) return ac.limit != 0, nil case p.len() == ft.maxDepth || node.leaf(): @@ -785,7 +824,11 @@ func (ft *fpTree) aggregateSimple(ac *aggContext) error { } case lcaIdx == noIndex || !lca.leaf(): ft.log("commonPrefix %s NOT found b/c no items have it", p) + case ac.nodeAtOrAfterX(lca, lcaPrefix) && ac.nodeBelowY(lca, lcaPrefix) && + ac.maybeIncludeNode(lca, lcaPrefix): + ft.log("commonPrefix %s -- lca node %d included in full", p, lcaIdx) default: + //ac.prefixAtOrAfterX(lcaPrefix) && ac.prefixBelowY(lcaPrefix): ft.log("commonPrefix %s -- lca %d", p, lcaIdx) _, err := ft.aggregateEdge(ac.x, ac.y, lcaPrefix, ac) return err @@ -811,7 +854,7 @@ func (ft *fpTree) aggregateInverse(ac *aggContext) error { ft.aggregateLeft(idx0, load64(ac.x)<") } - panic("unexpected type") + panic("unexpected type: " + reflect.TypeOf(k).String()) } } diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 77bae1608d..a10f8e2876 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -61,24 +61,38 @@ type SyncMessage interface { Keys() []Ordered } +func formatID(v any) string { + switch v := v.(type) { + case fmt.Stringer: + s := v.String() + if len(s) > 10 { + return s[:10] + } + return s + case string: + return v + default: + return "" + } +} + func SyncMessageToString(m SyncMessage) string { var sb strings.Builder sb.WriteString("<" + m.Type().String()) - if x := m.X(); x != nil { - sb.WriteString(" X=" + x.(fmt.Stringer).String()[:10]) + sb.WriteString(" X=" + formatID(x)) } if y := m.Y(); y != nil { - sb.WriteString(" Y=" + y.(fmt.Stringer).String()[:10]) + sb.WriteString(" Y=" + formatID(y)) } if count := m.Count(); 
count != 0 { fmt.Fprintf(&sb, " Count=%d", count) } if fp := m.Fingerprint(); fp != nil { - sb.WriteString(" FP=" + fp.(fmt.Stringer).String()[:10]) + sb.WriteString(" FP=" + formatID(fp)) } for _, k := range m.Keys() { - fmt.Fprintf(&sb, " item=%s", k.(fmt.Stringer).String()[:10]) + fmt.Fprintf(&sb, " item=%s", formatID(k)) } sb.WriteString(">") return sb.String() From 6ec95f68bc53074c969dd6968105c08b84ed7b80 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 31 Jul 2024 03:59:10 +0400 Subject: [PATCH 54/76] wip2 --- sync2/dbsync/dbitemstore.go | 51 ++++++----- sync2/dbsync/dbitemstore_test.go | 26 +++--- sync2/dbsync/dbiter.go | 135 +++++++++++++++++++++--------- sync2/dbsync/dbiter_test.go | 55 ++++++------ sync2/dbsync/fptree.go | 90 +++++++++++++------- sync2/dbsync/fptree_test.go | 32 +++---- sync2/dbsync/inmemidstore.go | 22 ++--- sync2/dbsync/inmemidstore_test.go | 25 +++--- sync2/dbsync/p2p_test.go | 20 +++-- sync2/dbsync/sqlidstore.go | 43 +++------- sync2/dbsync/sqlidstore_test.go | 15 ++-- sync2/hashsync/handler.go | 13 ++- sync2/hashsync/handler_test.go | 21 +++-- sync2/hashsync/interface.go | 2 +- sync2/hashsync/log.go | 6 +- sync2/hashsync/rangesync.go | 45 +++++++--- sync2/hashsync/rangesync_test.go | 75 ++++++++++++----- sync2/hashsync/setsyncbase.go | 5 +- sync2/hashsync/sync_tree_store.go | 9 +- sync2/p2p_test.go | 4 +- 20 files changed, 427 insertions(+), 267 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index a9fa4972b8..50ca936b56 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -2,7 +2,6 @@ package dbsync import ( "context" - "errors" "sync" "github.com/spacemeshos/go-spacemesh/common/types" @@ -65,7 +64,9 @@ func (d *DBItemStore) EnsureLoaded() error { // Add implements hashsync.ItemStore. 
func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { - d.EnsureLoaded() + if err := d.EnsureLoaded(); err != nil { + return err + } has, err := d.Has(k) // TODO: this check shouldn't be needed if has || err != nil { return err @@ -79,7 +80,9 @@ func (d *DBItemStore) GetRangeInfo( x, y hashsync.Ordered, count int, ) (hashsync.RangeInfo, error) { - d.EnsureLoaded() + if err := d.EnsureLoaded(); err != nil { + return hashsync.RangeInfo{}, err + } fpr, err := d.ft.fingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) if err != nil { return hashsync.RangeInfo{}, err @@ -94,16 +97,17 @@ func (d *DBItemStore) GetRangeInfo( // Min implements hashsync.ItemStore. func (d *DBItemStore) Min() (hashsync.Iterator, error) { - d.EnsureLoaded() - it, err := d.ft.start() - switch { - case err == nil: - return it, nil - case errors.Is(err, errEmptySet): + if err := d.EnsureLoaded(); err != nil { + return nil, err + } + if d.ft.count() == 0 { return nil, nil - default: + } + it := d.ft.start() + if _, err := it.Key(); err != nil { return nil, err } + return it, nil } // Copy implements hashsync.ItemStore. @@ -122,17 +126,20 @@ func (d *DBItemStore) Copy() hashsync.ItemStore { // Has implements hashsync.ItemStore. 
func (d *DBItemStore) Has(k hashsync.Ordered) (bool, error) { - d.EnsureLoaded() + if err := d.EnsureLoaded(); err != nil { + return false, err + } + if d.ft.count() == 0 { + return false, nil + } // TODO: should often be able to avoid querying the database if we check the key // against the fptree - it, err := d.ft.iter(k.(KeyBytes)) - if err == nil { - return k.Compare(it.Key()) == 0, nil - } - if err != errEmptySet { + it := d.ft.iter(k.(KeyBytes)) + itK, err := it.Key() + if err != nil { return false, err } - return false, nil + return itK.Compare(k) == 0, nil } // TODO: get rid of ItemStoreAdapter, it shouldn't be needed @@ -204,10 +211,14 @@ type iteratorAdapter struct { var _ hashsync.Iterator = &iteratorAdapter{} -func (ia *iteratorAdapter) Key() hashsync.Ordered { +func (ia *iteratorAdapter) Key() (hashsync.Ordered, error) { + k, err := ia.it.Key() + if err != nil { + return nil, err + } var h types.Hash32 - copy(h[:], ia.it.Key().(KeyBytes)) - return h + copy(h[:], k.(KeyBytes)) + return h, nil } func (ia *iteratorAdapter) Next() error { diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go index aae101866b..8af95aa786 100644 --- a/sync2/dbsync/dbitemstore_test.go +++ b/sync2/dbsync/dbitemstore_test.go @@ -50,7 +50,7 @@ func TestDBItemStore(t *testing.T) { it, err := s.Min() require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - it.Key().(KeyBytes).String()) + itKey(t, it).String()) has, err := s.Has(KeyBytes(util.FromHex("9876000000000000000000000000000000000000000000000000000000000000"))) require.NoError(t, err) require.False(t, has) @@ -115,8 +115,8 @@ func TestDBItemStore(t *testing.T) { require.NoError(t, err) require.Equal(t, tc.count, info.Count) require.Equal(t, tc.fp, info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[tc.startIdx], info.Start.Key().(KeyBytes)) - require.Equal(t, ids[tc.endIdx], info.End.Key().(KeyBytes)) + require.Equal(t, 
ids[tc.startIdx], itKey(t, info.Start)) + require.Equal(t, ids[tc.endIdx], itKey(t, info.End)) has, err := s.Has(ids[tc.startIdx]) require.NoError(t, err) require.True(t, has) @@ -139,7 +139,7 @@ func TestDBItemStoreAdd(t *testing.T) { it, err := s.Min() require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - it.Key().(KeyBytes).String()) + itKey(t, it).String()) newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) require.NoError(t, s.Add(context.Background(), newID)) @@ -154,8 +154,8 @@ func TestDBItemStoreAdd(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, info.Count) require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) - require.Equal(t, ids[0], info.End.Key().(KeyBytes)) + require.Equal(t, ids[2], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) } func TestDBItemStoreCopy(t *testing.T) { @@ -170,7 +170,7 @@ func TestDBItemStoreCopy(t *testing.T) { it, err := s.Min() require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - it.Key().(KeyBytes).String()) + itKey(t, it).String()) copy := s.Copy() @@ -178,8 +178,8 @@ func TestDBItemStoreCopy(t *testing.T) { require.NoError(t, err) require.Equal(t, 2, info.Count) require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) - require.Equal(t, ids[0], info.End.Key().(KeyBytes)) + require.Equal(t, ids[2], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) require.NoError(t, copy.Add(context.Background(), newID)) @@ -188,13 +188,13 @@ func TestDBItemStoreCopy(t *testing.T) { require.NoError(t, err) require.Equal(t, 2, info.Count) require.Equal(t, 
"dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) - require.Equal(t, ids[0], info.End.Key().(KeyBytes)) + require.Equal(t, ids[2], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) info, err = copy.GetRangeInfo(nil, ids[2], ids[0], -1) require.NoError(t, err) require.Equal(t, 3, info.Count) require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], info.Start.Key().(KeyBytes)) - require.Equal(t, ids[0], info.End.Key().(KeyBytes)) + require.Equal(t, ids[2], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) } diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index 59bd629d58..e834352d32 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -64,6 +64,7 @@ type dbRangeIterator struct { pos int keyLen int singleChunk bool + loaded bool } var _ iterator = &dbRangeIterator{} @@ -75,14 +76,16 @@ func newDBRangeIterator( query string, from KeyBytes, maxChunkSize int, -) (iterator, error) { +) iterator { if from == nil { panic("BUG: makeDBIterator: nil from") } if maxChunkSize <= 0 { panic("BUG: makeDBIterator: chunkSize must be > 0") } - it := &dbRangeIterator{ + // panic("TBD: QQQQQ: do not preload the iterator! Key should panic upon no entries. With from > max item, iterator should work, wrapping around (TEST)!") + // panic("TBD: QQQQQ: Key() should return an error!") + return &dbRangeIterator{ db: db, from: from.Clone(), query: query, @@ -91,13 +94,8 @@ func newDBRangeIterator( keyLen: len(from), chunk: make([]KeyBytes, maxChunkSize), singleChunk: false, + loaded: false, } - // panic("TBD: QQQQQ: do not preload the iterator! Key should panic upon no entries. 
With from > max item, iterator should work, wrapping around (TEST)!") - // panic("TBD: QQQQQ: Key() should return an error!") - if err := it.load(); err != nil { - return nil, err - } - return it, nil } func (it *dbRangeIterator) load() error { @@ -173,15 +171,29 @@ func (it *dbRangeIterator) load() error { return nil } -func (it *dbRangeIterator) Key() hashsync.Ordered { +func (it *dbRangeIterator) Key() (hashsync.Ordered, error) { + if !it.loaded { + if err := it.load(); err != nil { + return nil, err + } + it.loaded = true + } if it.pos < len(it.chunk) { - return slices.Clone(it.chunk[it.pos]) + return slices.Clone(it.chunk[it.pos]), nil } - return nil + return nil, errEmptySet } func (it *dbRangeIterator) Next() error { - if it.pos >= len(it.chunk) { + if !it.loaded { + if err := it.load(); err != nil { + return err + } + it.loaded = true + if len(it.chunk) == 0 || it.pos != 0 { + panic("BUG: load didn't report empty set or set a wrong pos") + } + it.pos++ return nil } it.pos++ @@ -202,43 +214,57 @@ func (it *dbRangeIterator) clone() iterator { } type combinedIterator struct { - iters []iterator - wrapped []iterator - ahead iterator - aheadIdx int + startingPoint hashsync.Ordered + iters []iterator + wrapped []iterator + ahead iterator + aheadIdx int } // combineIterators combines multiple iterators into one, returning the smallest current // key among all iterators at each step. func combineIterators(startingPoint hashsync.Ordered, iters ...iterator) iterator { - var c combinedIterator + return &combinedIterator{startingPoint: startingPoint, iters: iters} +} + +func (c *combinedIterator) begin() error { // Some of the iterators may already be wrapped around. // This corresponds to the case when we ask an idStore for iterator // with a starting point beyond the last key in the store. 
- if startingPoint == nil { - c.iters = iters - } else { - for _, it := range iters { - if it.Key().Compare(startingPoint) < 0 { - c.wrapped = append(c.wrapped, it) - } else { - c.iters = append(c.iters, it) + iters := c.iters + c.iters = nil + for _, it := range iters { + k, err := it.Key() + if err != nil { + if errors.Is(err, errEmptySet) { + // ignore empty iterators + continue } + return err } - if len(c.iters) == 0 { - // all iterators wrapped around - c.iters = c.wrapped - c.wrapped = nil + if c.startingPoint != nil && k.Compare(c.startingPoint) < 0 { + c.wrapped = append(c.wrapped, it) + } else { + c.iters = append(c.iters, it) } } - return &c + if len(c.iters) == 0 { + // all iterators wrapped around + c.iters = c.wrapped + c.wrapped = nil + } + c.startingPoint = nil + return nil } -func (c *combinedIterator) aheadIterator() iterator { +func (c *combinedIterator) aheadIterator() (iterator, error) { + if err := c.begin(); err != nil { + return nil, err + } if c.ahead == nil { if len(c.iters) == 0 { if len(c.wrapped) == 0 { - return nil + return nil, nil } c.iters = c.wrapped c.wrapped = nil @@ -246,29 +272,51 @@ func (c *combinedIterator) aheadIterator() iterator { c.ahead = c.iters[0] c.aheadIdx = 0 for i := 1; i < len(c.iters); i++ { - if c.iters[i].Key() != nil { - if c.ahead.Key() == nil || c.iters[i].Key().Compare(c.ahead.Key()) < 0 { + curK, err := c.iters[i].Key() + if err != nil { + return nil, err + } + if curK != nil { + aK, err := c.ahead.Key() + if err != nil { + return nil, err + } + if curK.Compare(aK) < 0 { c.ahead = c.iters[i] c.aheadIdx = i } } } } - return c.ahead + return c.ahead, nil } -func (c *combinedIterator) Key() hashsync.Ordered { - return c.aheadIterator().Key() +func (c *combinedIterator) Key() (hashsync.Ordered, error) { + it, err := c.aheadIterator() + if err != nil { + return nil, err + } + return it.Key() } func (c *combinedIterator) Next() error { - it := c.aheadIterator() - oldKey := it.Key() + it, err := c.aheadIterator() 
+ if err != nil { + return err + } + oldKey, err := it.Key() + if err != nil { + return err + } if err := it.Next(); err != nil { return err } c.ahead = nil - if oldKey.Compare(it.Key()) >= 0 { + newKey, err := it.Key() + if err != nil { + return err + } + if oldKey.Compare(newKey) >= 0 { // the iterator has wrapped around, move it to the wrapped list // which will be used after all the iterators have wrapped around c.wrapped = append(c.wrapped, it) @@ -279,10 +327,15 @@ func (c *combinedIterator) Next() error { func (c *combinedIterator) clone() iterator { cloned := &combinedIterator{ - iters: make([]iterator, len(c.iters)), + iters: make([]iterator, len(c.iters)), + wrapped: make([]iterator, len(c.wrapped)), + startingPoint: c.startingPoint, } for i, it := range c.iters { cloned.iters[i] = it.clone() } + for i, it := range c.wrapped { + cloned.wrapped[i] = it.clone() + } return cloned } diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index eeac4c779d..3463c79281 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -288,25 +288,25 @@ func TestDBRangeIterator(t *testing.T) { deleteDBItems(t, db) insertDBItems(t, db, tc.items) for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { - it, err := newDBRangeIterator(db, testQuery, tc.from, maxChunkSize) + it := newDBRangeIterator(db, testQuery, tc.from, maxChunkSize) if tc.expErr != nil { + _, err := it.Key() require.ErrorIs(t, err, tc.expErr) continue } - require.NoError(t, err) // when there are no items, errEmptySet is returned require.NotEmpty(t, tc.items) clonedIt := it.clone() var collected []KeyBytes for i := 0; i < len(tc.items); i++ { - k := it.Key() + k := itKey(t, it) require.NotNil(t, k) - collected = append(collected, k.(KeyBytes)) - require.Equal(t, k, clonedIt.Key()) + collected = append(collected, k) + require.Equal(t, k, itKey(t, clonedIt)) require.NoError(t, it.Next()) // calling Next on the original iterator // shouldn't affect the cloned one - 
require.Equal(t, k, clonedIt.Key()) + require.Equal(t, k, itKey(t, clonedIt)) require.NoError(t, clonedIt.Next()) } expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) @@ -315,11 +315,11 @@ func TestDBRangeIterator(t *testing.T) { clonedIt = it.clone() for range 2 { for i := 0; i < len(tc.items); i++ { - k := it.Key() - require.Equal(t, collected[i], k.(KeyBytes)) - require.Equal(t, k, clonedIt.Key()) + k := itKey(t, it) + require.Equal(t, collected[i], k) + require.Equal(t, k, itKey(t, clonedIt)) require.NoError(t, it.Next()) - require.Equal(t, k, clonedIt.Key()) + require.Equal(t, k, itKey(t, clonedIt)) require.NoError(t, clonedIt.Next()) } } @@ -333,14 +333,14 @@ type fakeIterator struct { var _ hashsync.Iterator = &fakeIterator{} -func (it *fakeIterator) Key() hashsync.Ordered { +func (it *fakeIterator) Key() (hashsync.Ordered, error) { if len(it.allItems) == 0 { - panic("no items") + return nil, errEmptySet } if len(it.items) == 0 { it.items = it.allItems } - return KeyBytes(it.items[0]) + return KeyBytes(it.items[0]), nil } func (it *fakeIterator) Next() error { @@ -384,11 +384,11 @@ func TestCombineIterators(t *testing.T) { for range 3 { var collected []KeyBytes for range 4 { - k := it.Key() - collected = append(collected, k.(KeyBytes)) - require.Equal(t, k, clonedIt.Key()) + k := itKey(t, it) + collected = append(collected, k) + require.Equal(t, k, itKey(t, clonedIt)) require.NoError(t, it.Next()) - require.Equal(t, k, clonedIt.Key()) + require.Equal(t, k, itKey(t, clonedIt)) require.NoError(t, clonedIt.Next()) } require.Equal(t, []KeyBytes{ @@ -397,23 +397,23 @@ func TestCombineIterators(t *testing.T) { {0x0a, 0x05, 0x00, 0x00}, {0xff, 0xff, 0xff, 0xff}, }, collected) - require.Equal(t, KeyBytes{0x00, 0x00, 0x00, 0x01}, it.Key()) + require.Equal(t, KeyBytes{0x00, 0x00, 0x00, 0x01}, itKey(t, it)) } it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 
1}}} it = combineIterators(nil, it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.Equal(t, KeyBytes{0, 0, 0, 0}, itKey(t, it)) require.Error(t, it.Next()) it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}}} it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} it = combineIterators(nil, it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, it.Key()) + require.Equal(t, KeyBytes{0, 0, 0, 0}, itKey(t, it)) require.NoError(t, it.Next()) - require.Equal(t, KeyBytes{0, 0, 0, 1}, it.Key()) + require.Equal(t, KeyBytes{0, 0, 0, 1}, itKey(t, it)) require.Error(t, it.Next()) } @@ -434,8 +434,8 @@ func TestCombineIteratorsInitiallyWrapped(t *testing.T) { it := combineIterators(KeyBytes{0xff, 0x00, 0x00, 0x55}, it1, it2) var collected []KeyBytes for range 4 { - k := it.Key() - collected = append(collected, k.(KeyBytes)) + k := itKey(t, it) + collected = append(collected, k) require.NoError(t, it.Next()) } require.Equal(t, []KeyBytes{ @@ -444,5 +444,12 @@ func TestCombineIteratorsInitiallyWrapped(t *testing.T) { {0x00, 0x00, 0x00, 0x03}, {0x0a, 0x05, 0x00, 0x00}, }, collected) - require.Equal(t, KeyBytes{0xff, 0x00, 0x00, 0x55}, it.Key()) + require.Equal(t, KeyBytes{0xff, 0x00, 0x00, 0x55}, itKey(t, it)) +} + +func itKey(t *testing.T, it hashsync.Iterator) KeyBytes { + k, err := it.Key() + require.NoError(t, err) + require.NotNil(t, k) + return k.(KeyBytes) } diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 8fc2aca0ce..03bcf8ee2e 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "encoding/hex" - "errors" "fmt" "io" "math/bits" @@ -30,6 +29,11 @@ func (t *trace) enter(format string, args ...any) { if !t.traceEnabled { return } + for n, arg := range args { + if it, ok := arg.(iterator); ok { + args[n] = formatIter(it) + } + } msg := fmt.Sprintf(format, args...) 
t.out("ENTER: " + msg) t.traceStack = append(t.traceStack, msg) @@ -42,6 +46,11 @@ func (t *trace) leave(results ...any) { if len(t.traceStack) == 0 { panic("BUG: trace stack underflow") } + for n, r := range results { + if it, ok := r.(iterator); ok { + results[n] = formatIter(it) + } + } msg := t.traceStack[len(t.traceStack)-1] if len(results) != 0 { var r []string @@ -56,6 +65,11 @@ func (t *trace) leave(results ...any) { func (t *trace) log(format string, args ...any) { if t.traceEnabled { + for n, arg := range args { + if it, ok := arg.(iterator); ok { + args[n] = formatIter(it) + } + } msg := fmt.Sprintf(format, args...) t.out(msg) } @@ -368,8 +382,8 @@ type iterator interface { type idStore interface { clone() idStore registerHash(h KeyBytes) error - start() (iterator, error) - iter(from KeyBytes) (iterator, error) + start() iterator + iter(from KeyBytes) iterator } type fpTree struct { @@ -576,41 +590,38 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b startFrom = x } ft.log("aggregateEdge: startFrom %s", startFrom) - it, err := ft.iter(startFrom) - if err != nil { - if errors.Is(err, errEmptySet) { - ft.log("aggregateEdge: empty set") - return false, nil - } - ft.log("aggregateEdge: error: %v", err) - return false, err - } + it := ft.iter(startFrom) if ac.limit == 0 { ac.end = it.clone() if x != nil { + ft.log("aggregateEdge: limit 0: x is not nil, setting start to %s", ac.start) ac.start = ac.end } - ft.log("aggregateEdge: limit is 0 at %s", ac.end.Key().(fmt.Stringer)) + ft.log("aggregateEdge: limit is 0 at %s", ac.end) return false, nil } if x != nil { ac.start = it.clone() + ft.log("aggregateEdge: x is not nil, setting start to %s", ac.start) } for range ft.np.node(ft.root).c { - id := it.Key().(KeyBytes) + id, err := it.Key() + if err != nil { + return false, err + } ft.log("aggregateEdge: ID %s", id) if y != nil && id.Compare(y) >= 0 { ac.end = it ft.log("aggregateEdge: ID is over Y: %s", id) return false, nil } - 
if !p.match(id) { + if !p.match(id.(KeyBytes)) { ft.log("aggregateEdge: ID doesn't match the prefix: %s", id) ac.lastPrefix = &p return true, nil } - ac.fp.update(id) + ac.fp.update(id.(KeyBytes)) ac.count++ if ac.limit > 0 { ac.limit-- @@ -923,7 +934,7 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e ft.enter("fingerprintInterval: x %s y %s limit %d", x, y, limit) defer func() { if fpr.start != nil && fpr.end != nil { - ft.leave(fpr.fp, fpr.count, fpr.itype, fpr.start.Key(), fpr.end.Key()) + ft.leave(fpr.fp, fpr.count, fpr.itype, fpr.start, fpr.end) } else { ft.leave(fpr.fp, fpr.count, fpr.itype) } @@ -943,31 +954,27 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e } if ac.start != nil { - ft.log("fingerprintInterval: start %s", ac.start.Key().(fmt.Stringer)) + ft.log("fingerprintInterval: start %s", ac.start) fpr.start = ac.start - } else if fpr.start, err = ft.iter(x); err != nil { - return fpResult{}, err } else { - ft.log("fingerprintInterval: start from x: %s", fpr.start.Key().(fmt.Stringer)) + fpr.start = ft.iter(x) + ft.log("fingerprintInterval: start from x: %s", fpr.start) } if ac.end != nil { - ft.log("fingerprintInterval: end %s", ac.end.Key().(fmt.Stringer)) + ft.log("fingerprintInterval: end %s", ac.end) fpr.end = ac.end } else if (fpr.itype == 0 && limit < 0) || fpr.count == 0 { fpr.end = fpr.start - ft.log("fingerprintInterval: end at start %s", fpr.end.Key().(fmt.Stringer)) + ft.log("fingerprintInterval: end at start %s", fpr.end) } else if ac.lastPrefix != nil { k := make(KeyBytes, ft.keyLen) ac.lastPrefix.idAfter(k) - if fpr.end, err = ft.iter(k); err != nil { - return fpResult{}, err - } - ft.log("fingerprintInterval: end at lastPrefix %s", fpr.end.Key().(fmt.Stringer)) - } else if fpr.end, err = ft.iter(y); err != nil { - return fpResult{}, err + fpr.end = ft.iter(k) + ft.log("fingerprintInterval: end at lastPrefix %s", fpr.end) } else { - ft.log("fingerprintInterval: end at 
y: %s", fpr.end.Key().(fmt.Stringer)) + fpr.end = ft.iter(y) + ft.log("fingerprintInterval: end at y: %s", fpr.end) } return fpr, nil @@ -1000,6 +1007,29 @@ func (ft *fpTree) dump(w io.Writer) { } } +func (ft *fpTree) count() int { + if ft.root == noIndex { + return 0 + } + return int(ft.np.node(ft.root).c) +} + +type iterFormatter struct { + it iterator +} + +func (f iterFormatter) String() string { + if k, err := f.it.Key(); err != nil { + return fmt.Sprintf("", err) + } else { + return k.(fmt.Stringer).String() + } +} + +func formatIter(it iterator) fmt.Stringer { + return iterFormatter{it: it} +} + // TBD: optimize, get rid of binary.BigEndian.* // TBD: QQQQQ: detect unbalancedness when a ref gets too many items // TBD: QQQQQ: ItemStore.Close(): close db conns, also free fpTree instead of using finalizer! diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 12d3e3f422..af47ba5228 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -738,7 +738,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { } else { require.NotNil(t, fpr.start, "start") expK := KeyBytes(hs[rtc.startIdx][:]) - assert.Equal(t, expK, fpr.start.Key(), "start") + assert.Equal(t, expK, itKey(t, fpr.start), "start") } if rtc.endIdx == -1 { @@ -746,7 +746,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { } else { require.NotNil(t, fpr.end, "end") expK := KeyBytes(hs[rtc.endIdx][:]) - assert.Equal(t, expK, fpr.end.Key(), "end") + assert.Equal(t, expK, itKey(t, fpr.end), "end") } }) } @@ -785,19 +785,19 @@ func (noIDStore) registerHash(h KeyBytes) error { return nil } -func (noIDStore) start() (iterator, error) { +func (noIDStore) start() iterator { panic("no ID store") } -func (noIDStore) iter(from KeyBytes) (iterator, error) { - return noIter{}, nil +func (noIDStore) iter(from KeyBytes) iterator { + return noIter{} } type noIter struct{} -func (noIter) Key() hashsync.Ordered { - return make(KeyBytes, 32) +func (noIter) Key() 
(hashsync.Ordered, error) { + return make(KeyBytes, 32), nil } func (noIter) Next() error { @@ -999,17 +999,17 @@ type fpResultWithBounds struct { end KeyBytes } -func toFPResultWithBounds(fpr fpResult) fpResultWithBounds { +func toFPResultWithBounds(t *testing.T, fpr fpResult) fpResultWithBounds { r := fpResultWithBounds{ fp: fpr.fp, count: fpr.count, itype: fpr.itype, } if fpr.start != nil { - r.start = fpr.start.Key().(KeyBytes) + r.start = itKey(t, fpr.start) } if fpr.end != nil { - r.end = fpr.end.Key().(KeyBytes) + r.end = itKey(t, fpr.end) } return r } @@ -1092,11 +1092,11 @@ func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, li expFPR := dumbFP(hs, x, y, limit) fpr, err := ft.fingerprintInterval(x[:], y[:], limit) require.NoError(t, err) - require.Equal(t, expFPR, toFPResultWithBounds(fpr), + require.Equal(t, expFPR, toFPResultWithBounds(t, fpr), "x=%s y=%s limit=%d", x.String(), y.String(), limit) // QQQQQ: rm - if !reflect.DeepEqual(toFPResultWithBounds(fpr), expFPR) { + if !reflect.DeepEqual(toFPResultWithBounds(t, fpr), expFPR) { t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) for _, h := range hs { t.Logf("QQQQQ: hash: %s", h.String()) @@ -1107,7 +1107,7 @@ func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, li } // QQQQQ: /rm - require.Equal(t, expFPR, toFPResultWithBounds(fpr), + require.Equal(t, expFPR, toFPResultWithBounds(t, fpr), "x=%s y=%s limit=%d", x.String(), y.String(), limit) return fpr @@ -1123,7 +1123,7 @@ func verifySubIntervals(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32 } part := verifyInterval(t, hs, ft, x, y, c) var m types.Hash32 - copy(m[:], part.end.Key().(KeyBytes)) + copy(m[:], itKey(t, part.end)) verifySubIntervals(t, hs, ft, x, m, -1, d+1) verifySubIntervals(t, hs, ft, m, y, -1, d+1) } @@ -1376,7 +1376,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { //expFPResult := dumbAggATXs(t, db, x, y) fpr, err := ft.fingerprintInterval(x[:], y[:], 
-1) require.NoError(t, err) - require.Equal(t, expFPResult, toFPResultWithBounds(fpr), + require.Equal(t, expFPResult, toFPResultWithBounds(t, fpr), "x=%s y=%s", x.String(), y.String()) limit := 0 @@ -1387,7 +1387,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { expFPResult = dumbFP(*hs, x, y, limit) fpr, err = ft.fingerprintInterval(x[:], y[:], limit) require.NoError(t, err) - require.Equal(t, expFPResult, toFPResultWithBounds(fpr), + require.Equal(t, expFPResult, toFPResultWithBounds(t, fpr), "x=%s y=%s limit=%d", x.String(), y.String(), limit) } diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index 5fbc56a3a4..c3619f21bb 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -34,23 +34,16 @@ func (s *inMemIDStore) registerHash(h KeyBytes) error { return nil } -func (s *inMemIDStore) start() (iterator, error) { - node := s.sl.First() - if node == nil { - return nil, errEmptySet - } - return &inMemIDStoreIterator{sl: s.sl, node: node}, nil +func (s *inMemIDStore) start() iterator { + return &inMemIDStoreIterator{sl: s.sl, node: s.sl.First()} } -func (s *inMemIDStore) iter(from KeyBytes) (iterator, error) { +func (s *inMemIDStore) iter(from KeyBytes) iterator { node := s.sl.FindGTENode(from) if node == nil { node = s.sl.First() - if node == nil { - return nil, errEmptySet - } } - return &inMemIDStoreIterator{sl: s.sl, node: node}, nil + return &inMemIDStoreIterator{sl: s.sl, node: node} } type inMemIDStoreIterator struct { @@ -60,8 +53,11 @@ type inMemIDStoreIterator struct { var _ iterator = &inMemIDStoreIterator{} -func (it *inMemIDStoreIterator) Key() hashsync.Ordered { - return KeyBytes(it.node.Key()) +func (it *inMemIDStoreIterator) Key() (hashsync.Ordered, error) { + if it.node == nil { + return nil, errEmptySet + } + return KeyBytes(it.node.Key()), nil } func (it *inMemIDStoreIterator) Next() error { diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go index 
f8ca3719ed..623e881757 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -16,10 +16,10 @@ func TestInMemIDStore(t *testing.T) { ) s := newInMemIDStore(32) - _, err = s.start() + _, err = s.start().Key() require.ErrorIs(t, err, errEmptySet) - _, err = s.iter(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) + _, err = s.iter(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")).Key() require.ErrorIs(t, err, errEmptySet) for _, h := range []string{ @@ -36,15 +36,14 @@ func TestInMemIDStore(t *testing.T) { for i := range 6 { if i%2 == 0 { - it, err = s.start() + it = s.start() } else { - it, err = s.iter( + it = s.iter( util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) } - require.NoError(t, err) var items []string for range 7 { - items = append(items, hex.EncodeToString(it.Key().(KeyBytes))) + items = append(items, hex.EncodeToString(itKey(t, it))) require.NoError(t, it.Next()) } require.Equal(t, []string{ @@ -58,18 +57,17 @@ func TestInMemIDStore(t *testing.T) { }, items) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - hex.EncodeToString(it.Key().(KeyBytes))) + hex.EncodeToString(itKey(t, it))) s1 := s.clone() h := types.BytesToHash( util.FromHex("2000000000000000000000000000000000000000000000000000000000000000")) s1.registerHash(h[:]) items = nil - it, err = s1.iter( + it = s1.iter( util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) - require.NoError(t, err) for range 8 { - items = append(items, hex.EncodeToString(it.Key().(KeyBytes))) + items = append(items, hex.EncodeToString(itKey(t, it))) require.NoError(t, it.Next()) } require.Equal(t, []string{ @@ -84,13 +82,12 @@ func TestInMemIDStore(t *testing.T) { }, items) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - hex.EncodeToString(it.Key().(KeyBytes))) + 
hex.EncodeToString(itKey(t, it))) - it, err = s1.iter( + it = s1.iter( util.FromHex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0")) - require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - hex.EncodeToString(it.Key().(KeyBytes))) + hex.EncodeToString(itKey(t, it))) } } diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index b258735d94..5d32a03554 100644 --- a/sync2/dbsync/p2p_test.go +++ b/sync2/dbsync/p2p_test.go @@ -110,12 +110,16 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { assert.Nil(t, it) } else { for range combinedItems { - // t.Logf("synced itemA: %s", it.Key().(types.Hash32).String()) - h := it.Key().(types.Hash32) + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) + // t.Logf("synced itemA: %s", h.String()) actItemsA = append(actItemsA, h[:]) require.NoError(t, it.Next()) } - h := it.Key().(types.Hash32) + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) assert.Equal(t, actItemsA[0], KeyBytes(h[:])) } @@ -126,12 +130,16 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { assert.Nil(t, it) } else { for range combinedItems { - // t.Logf("synced itemB: %s", it.Key().(types.Hash32).String()) - h := it.Key().(types.Hash32) + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) + // t.Logf("synced itemB: %s", h.String()) actItemsB = append(actItemsB, h[:]) require.NoError(t, it.Next()) } - h := it.Key().(types.Hash32) + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) assert.Equal(t, actItemsB[0], KeyBytes(h[:])) assert.Equal(t, combinedItems, actItemsA) assert.Equal(t, actItemsA, actItemsB) diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 8168f48919..2569953c43 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -2,7 +2,6 @@ package dbsync import ( "bytes" - "errors" 
"github.com/spacemeshos/go-spacemesh/sql" ) @@ -30,12 +29,12 @@ func (s *sqlIDStore) registerHash(h KeyBytes) error { return nil } -func (s *sqlIDStore) start() (iterator, error) { +func (s *sqlIDStore) start() iterator { // TODO: should probably use a different query to get the first key return s.iter(make(KeyBytes, s.keyLen)) } -func (s *sqlIDStore) iter(from KeyBytes) (iterator, error) { +func (s *sqlIDStore) iter(from KeyBytes) iterator { if len(from) != s.keyLen { panic("BUG: invalid key length") } @@ -67,38 +66,16 @@ func (s *dbBackedStore) registerHash(h KeyBytes) error { return s.inMemIDStore.registerHash(h) } -func (s *dbBackedStore) start() (iterator, error) { - dbIt, err := s.sqlIDStore.start() - if err != nil { - if errors.Is(err, errEmptySet) { - return s.inMemIDStore.start() - } - return nil, err - } - memIt, err := s.inMemIDStore.start() - if err == nil { - return combineIterators(nil, dbIt, memIt), nil - } else if errors.Is(err, errEmptySet) { - return dbIt, nil - } - return nil, err +func (s *dbBackedStore) start() iterator { + dbIt := s.sqlIDStore.start() + memIt := s.inMemIDStore.start() + return combineIterators(nil, dbIt, memIt) } -func (s *dbBackedStore) iter(from KeyBytes) (iterator, error) { - dbIt, err := s.sqlIDStore.iter(from) - if err != nil { - if errors.Is(err, errEmptySet) { - return s.inMemIDStore.iter(from) - } - return nil, err - } - memIt, err := s.inMemIDStore.iter(from) - if err == nil { - return combineIterators(from, dbIt, memIt), nil - } else if errors.Is(err, errEmptySet) { - return dbIt, nil - } - return nil, err +func (s *dbBackedStore) iter(from KeyBytes) iterator { + dbIt := s.sqlIDStore.iter(from) + memIt := s.inMemIDStore.iter(from) + return combineIterators(from, dbIt, memIt) } func idWithinInterval(id, x, y KeyBytes, itype int) bool { diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index f756bac693..02e9169b67 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ 
b/sync2/dbsync/sqlidstore_test.go @@ -16,10 +16,9 @@ func TestDBBackedStore(t *testing.T) { db := populateDB(t, 8, initialIDs) store := newDBBackedStore(db, fakeIDQuery, 8) var actualIDs []KeyBytes - it, err := store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) - require.NoError(t, err) + it := store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) for range 5 { - actualIDs = append(actualIDs, it.Key().(KeyBytes)) + actualIDs = append(actualIDs, itKey(t, it)) require.NoError(t, it.Next()) } require.Equal(t, []KeyBytes{ @@ -30,20 +29,18 @@ func TestDBBackedStore(t *testing.T) { {0, 0, 0, 1, 0, 0, 0, 0}, // wrapped around }, actualIDs) - it, err = store.start() - require.NoError(t, err) + it = store.start() for n := range 5 { - require.Equal(t, actualIDs[n], it.Key().(KeyBytes)) + require.Equal(t, actualIDs[n], itKey(t, it)) require.NoError(t, it.Next()) } require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) actualIDs = nil - it, err = store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) - require.NoError(t, err) + it = store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) for range 6 { - actualIDs = append(actualIDs, it.Key().(KeyBytes)) + actualIDs = append(actualIDs, itKey(t, it)) require.NoError(t, it.Next()) } require.Equal(t, []KeyBytes{ diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index 2b7939266e..627e17e8ef 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -160,10 +160,11 @@ func (c *wireConduit) SendItems(count, itemChunkSize int, it Iterator) error { var msg ItemBatchMessage n := min(itemChunkSize, count-i) for n > 0 { - if it.Key() == nil { - panic("fakeConduit.SendItems: went got to the end of the tree") + k, err := it.Key() + if err != nil { + return err } - msg.ContentKeys = append(msg.ContentKeys, it.Key().(types.Hash32)) + msg.ContentKeys = append(msg.ContentKeys, k.(types.Hash32)) if err := it.Next(); err != nil { return err } @@ -209,7 +210,11 
@@ func (c *wireConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sa } // fmt.Fprintf(os.Stderr, "QQQQQ: begin sending items\n") for n := 0; n < sampleSize; n++ { - m.Sample[n] = MinhashSampleItemFromHash32(it.Key().(types.Hash32)) + k, err := it.Key() + if err != nil { + return err + } + m.Sample[n] = MinhashSampleItemFromHash32(k.(types.Hash32)) // fmt.Fprintf(os.Stderr, "QQQQQ: m.Sample[%d] = %s\n", n, m.Sample[n]) if err := it.Next(); err != nil { return err diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index 1764626b42..58b37361d7 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -118,11 +118,11 @@ func (it *sliceIterator) Equal(other Iterator) bool { return false } -func (it *sliceIterator) Key() Ordered { +func (it *sliceIterator) Key() (Ordered, error) { if len(it.s) != 0 { - return it.s[0] + return it.s[0], nil } - return nil + return nil, nil } func (it *sliceIterator) Next() error { @@ -478,7 +478,9 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { pss := NewPairwiseStoreSyncer(client, opts) minA, err := storeA.Min() require.NoError(t, err) - infoA, err := storeA.GetRangeInfo(nil, minA.Key(), minA.Key(), -1) + kA, err := minA.Key() + require.NoError(t, err) + infoA, err := storeA.GetRangeInfo(nil, kA, kA, -1) require.NoError(t, err) prA, err := pss.Probe(ctx, srvPeerID, storeB, nil, nil) require.NoError(t, err) @@ -488,10 +490,15 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { minA, err = storeA.Min() require.NoError(t, err) - partInfoA, err := storeA.GetRangeInfo(nil, minA.Key(), minA.Key(), infoA.Count/2) + kA, err = minA.Key() + require.NoError(t, err) + partInfoA, err := storeA.GetRangeInfo(nil, kA, kA, infoA.Count/2) + require.NoError(t, err) + xK, err := partInfoA.Start.Key() require.NoError(t, err) - x := partInfoA.Start.Key().(types.Hash32) - y := partInfoA.End.Key().(types.Hash32) + x := xK.(types.Hash32) + 
yK, err := partInfoA.End.Key() + y := yK.(types.Hash32) // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) prA, err = pss.Probe(ctx, srvPeerID, storeB, &x, &y) require.NoError(t, err) diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 88aa1456ec..32bfc1f46d 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -17,7 +17,7 @@ type Iterator interface { // nil if the ItemStore is empty // If the iterator is returned along with a count, the return value of Key() // after calling Next() count times is dependent on the implementation. - Key() Ordered + Key() (Ordered, error) // Next advances the iterator Next() error } diff --git a/sync2/hashsync/log.go b/sync2/hashsync/log.go index 369843d5a7..8d3f085b06 100644 --- a/sync2/hashsync/log.go +++ b/sync2/hashsync/log.go @@ -14,7 +14,11 @@ func IteratorField(name string, it Iterator) zap.Field { if it == nil { return zap.String(name, "") } - return HexField(name, it.Key()) + k, err := it.Key() + if err != nil { + return zap.String(name, fmt.Sprintf("", err)) + } + return HexField(name, k) } // based on code from testify diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index a10f8e2876..392310f858 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -209,8 +209,14 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang // } func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, y Ordered) (Iterator, error) { - if preceding != nil && preceding.Key().Compare(x) > 0 { - preceding = nil + if preceding != nil { + k, err := preceding.Key() + if err != nil { + return nil, err + } + if k.Compare(x) > 0 { + preceding = nil + } } // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", // qqqqRmmeK(preceding)) @@ -286,7 +292,10 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg } return nil, true, nil } - x = it.Key() + x, err = it.Key() + if err != nil { + 
return nil, false, err + } y = x } else if x == nil || y == nil { return nil, false, errors.New("bad X or Y") @@ -394,7 +403,10 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg HexField("fingerprint", part.Fingerprint), IteratorField("start", part.Start), IteratorField("middle", part.End)) - middle := part.End.Key() + middle, err := part.End.Key() + if err != nil { + return nil, false, err + } next, err := rsr.processSubrange(c, info.Start, x, middle) if err != nil { return nil, false, err @@ -418,7 +430,10 @@ func (rsr *RangeSetReconciler) Initiate(c Conduit) error { } var x Ordered if it != nil { - x = it.Key() + x, err = it.Key() + if err != nil { + return err + } } return rsr.InitiateBounded(c, x, x) } @@ -512,10 +527,11 @@ func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ for n := 0; n < sampleSize; n++ { // fmt.Fprintf(os.Stderr, "QQQQQ: n %d sampleSize %d info.Count %d rsr.sampleSize %d %#v\n", // n, sampleSize, info.Count, rsr.sampleSize, it.Key()) - if it.Key() == nil { - panic("BUG: no key") + k, err := it.Key() + if err != nil { + return 0, err } - localSample[n] = c.ShortenKey(it.Key()) + localSample[n] = c.ShortenKey(k) if err := it.Next(); err != nil { return 0, err } @@ -666,10 +682,14 @@ func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { if err != nil { return nil, err } - if it == nil || it.Key() == nil { + if it == nil { return nil, nil } - info, err := is.GetRangeInfo(nil, it.Key(), it.Key(), -1) + k, err := it.Key() + if err != nil { + return nil, err + } + info, err := is.GetRangeInfo(nil, k, k, -1) if err != nil { return nil, err } @@ -678,7 +698,10 @@ func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { return nil, err } for n := 0; n < info.Count; n++ { - k := it.Key() + k, err := it.Key() + if err != nil { + return nil, err + } if k == nil { fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) panic("BUG: iterator exausted before Count reached") diff --git 
a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index e396c66c34..4fac1a6845 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -120,10 +120,11 @@ func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { msg := rangeMessage{mtype: MessageTypeItemBatch} n := min(itemChunkSize, count-i) for n > 0 { - if it.Key() == nil { - panic("fakeConduit.SendItems: went got to the end of the tree") + k, err := it.Key() + if err != nil { + return fmt.Errorf("getting item: %w", err) } - msg.keys = append(msg.keys, it.Key()) + msg.keys = append(msg.keys, k) if err := it.Next(); err != nil { return err } @@ -165,8 +166,10 @@ func (fc *fakeConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, s keys: make([]Ordered, sampleSize), } for n := 0; n < sampleSize; n++ { - require.NotNil(fc.t, it.Key()) - msg.keys[n] = it.Key() + k, err := it.Key() + require.NoError(fc.t, err) + require.NotNil(fc.t, k) + msg.keys[n] = k if err := it.Next(); err != nil { return err } @@ -194,8 +197,8 @@ func (it *dumbStoreIterator) Equal(other Iterator) bool { return it.n == o.n } -func (it *dumbStoreIterator) Key() Ordered { - return it.ds.keys[it.n] +func (it *dumbStoreIterator) Key() (Ordered, error) { + return it.ds.keys[it.n], nil } func (it *dumbStoreIterator) Next() error { @@ -265,7 +268,10 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) ( Fingerprint: "", }, nil } else { - x = it.Key() + x, err = it.Key() + if err != nil { + return RangeInfo{}, err + } y = x } } else if x == nil || y == nil { @@ -277,8 +283,14 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) ( } vx := x.(sampleID) vy := y.(sampleID) - if preceding != nil && preceding.Key().Compare(x) > 0 { - panic("preceding info after x") + if preceding != nil { + k, err := preceding.Key() + if err != nil { + return RangeInfo{}, err + } + if k.Compare(x) > 0 { + panic("preceding info after 
x") + } } fp, startStr, endStr := naiveRange(all, string(vx), string(vy), count) r := RangeInfo{ @@ -326,11 +338,16 @@ type verifiedStoreIterator struct { var _ Iterator = &verifiedStoreIterator{} -func (it verifiedStoreIterator) Key() Ordered { - k1 := it.knownGood.Key() - k2 := it.it.Key() - assert.Equal(it.t, k1, k2, "keys") - return k2 +func (it verifiedStoreIterator) Key() (Ordered, error) { + k1, err := it.knownGood.Key() + if err != nil { + return nil, err + } + k2, err := it.it.Key() + if err == nil { + assert.Equal(it.t, k1, k2, "keys") + } + return k2, nil } func (it verifiedStoreIterator) Next() error { @@ -338,7 +355,15 @@ func (it verifiedStoreIterator) Next() error { err2 := it.it.Next() switch { case err1 == nil && err2 == nil: - assert.Equal(it.t, it.knownGood.Key(), it.it.Key(), "keys for Next()") + k1, err := it.knownGood.Key() + if err != nil { + return err + } + k2, err := it.it.Key() + if err != nil { + return err + } + assert.Equal(it.t, k1, k2, "keys for Next()") case err1 != nil && err2 != nil: return err2 default: @@ -410,7 +435,11 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in require.Nil(vs.t, ri2.End, "range info end") } else { require.NotNil(vs.t, ri2.Start, "range info start") - require.Equal(vs.t, ri1.Start.Key(), ri2.Start.Key(), "range info start key") + k1, err := ri1.Start.Key() + require.NoError(vs.t, err) + k2, err := ri2.Start.Key() + require.NoError(vs.t, err) + require.Equal(vs.t, k1, k2, "range info start key") require.NotNil(vs.t, ri1.End, "range info end (known good)") require.NotNil(vs.t, ri2.End, "range info end") ri.Start = verifiedStoreIterator{ @@ -423,7 +452,11 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in require.Nil(vs.t, ri2.End, "range info end") } else { require.NotNil(vs.t, ri2.End, "range info end") - require.Equal(vs.t, ri1.End.Key(), ri2.End.Key(), "range info end key") + k1, err := ri1.Start.Key() + require.NoError(vs.t, err) + k2, 
err := ri2.Start.Key() + require.NoError(vs.t, err) + require.Equal(vs.t, k1, k2, "range info end key") ri.End = verifiedStoreIterator{ t: vs.t, knownGood: ri1.End, @@ -445,7 +478,11 @@ func (vs *verifiedStore) Min() (Iterator, error) { return nil, nil } else { require.NotNil(vs.t, m2, "Min") - require.Equal(vs.t, m1.Key(), m2.Key(), "Min key") + k1, err := m1.Key() + require.NoError(vs.t, err) + k2, err := m2.Key() + require.NoError(vs.t, err) + require.Equal(vs.t, k1, k2, "Min key") } return verifiedStoreIterator{ t: vs.t, diff --git a/sync2/hashsync/setsyncbase.go b/sync2/hashsync/setsyncbase.go index ed49c86e17..15b8d7225f 100644 --- a/sync2/hashsync/setsyncbase.go +++ b/sync2/hashsync/setsyncbase.go @@ -43,7 +43,10 @@ func (ssb *SetSyncBase) Count() (int, error) { if it == nil || err != nil { return 0, err } - x := it.Key() + x, err := it.Key() + if err != nil { + return 0, err + } info, err := ssb.is.GetRangeInfo(nil, x, x, -1) if err != nil { return 0, err diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index 43df94cade..271c411c51 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -17,8 +17,8 @@ func (it *syncTreeIterator) Equal(other Iterator) bool { return it.ptr.Equal(o.ptr) } -func (it *syncTreeIterator) Key() Ordered { - return it.ptr.Key() +func (it *syncTreeIterator) Key() (Ordered, error) { + return it.ptr.Key(), nil } func (it *syncTreeIterator) Next() error { @@ -71,7 +71,10 @@ func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count i Fingerprint: sts.identity, }, nil } else { - x = it.Key() + x, err = it.Key() + if err != nil { + return RangeInfo{}, err + } y = x } } else if x == nil || y == nil { diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index d6d95813be..c660440657 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -75,7 +75,9 @@ func TestP2P(t *testing.T) { if it == nil { return false } - info, err := is.GetRangeInfo(nil, it.Key(), 
it.Key(), -1) + k, err := it.Key() + require.NoError(t, err) + info, err := is.GetRangeInfo(nil, k, k, -1) require.NoError(t, err) if info.Count < numHashes { return false From 183908190de9ea94fe3baa9012afd3ffa2b1f65b Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 1 Aug 2024 06:48:46 +0400 Subject: [PATCH 55/76] wip3 --- sync2/dbsync/dbitemstore.go | 55 +++++ sync2/dbsync/dbiter.go | 26 +-- sync2/dbsync/dbiter_test.go | 8 +- sync2/dbsync/fptree.go | 367 ++++++++++++++++++++---------- sync2/dbsync/fptree_test.go | 82 ++++++- sync2/dbsync/inmemidstore.go | 8 +- sync2/dbsync/inmemidstore_test.go | 6 +- sync2/dbsync/sqlidstore.go | 9 +- sync2/hashsync/handler_test.go | 5 + sync2/hashsync/interface.go | 5 + sync2/hashsync/mocks_test.go | 53 ++++- sync2/hashsync/rangesync.go | 49 ++-- sync2/hashsync/rangesync_test.go | 98 ++++++-- sync2/hashsync/sync_tree.go | 8 + sync2/hashsync/sync_tree_store.go | 35 ++- 15 files changed, 602 insertions(+), 212 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 50ca936b56..83e7c61724 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -95,6 +95,32 @@ func (d *DBItemStore) GetRangeInfo( }, nil } +func (d *DBItemStore) SplitRange( + preceding hashsync.Iterator, + x, y hashsync.Ordered, + count int, +) (hashsync.RangeInfo, hashsync.RangeInfo, error) { + if err := d.EnsureLoaded(); err != nil { + return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err + } + panic("TBD") + // fpr1, fpr2, err := d.ft.splitFingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) + // if err != nil { + // return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err + // } + // return hashsync.RangeInfo{ + // Fingerprint: fpr1.fp, + // Count: int(fpr1.count), + // Start: fpr1.start, + // End: fpr1.end, + // }, hashsync.RangeInfo{ + // Fingerprint: fpr2.fp, + // Count: int(fpr2.count), + // Start: fpr2.start, + // End: fpr2.end, + // }, nil +} + // Min implements hashsync.ItemStore. 
func (d *DBItemStore) Min() (hashsync.Iterator, error) { if err := d.EnsureLoaded(); err != nil { @@ -190,6 +216,31 @@ func (a *ItemStoreAdapter) GetRangeInfo(preceding hashsync.Iterator, x hashsync. }, nil } +func (a *ItemStoreAdapter) SplitRange(preceding hashsync.Iterator, x hashsync.Ordered, y hashsync.Ordered, count int) (hashsync.RangeInfo, hashsync.RangeInfo, error) { + hx := x.(types.Hash32) + hy := y.(types.Hash32) + info1, info2, err := a.s.SplitRange(preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) + if err != nil { + return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err + } + var fp1, fp2 types.Hash12 + src1 := info1.Fingerprint.(fingerprint) + src2 := info2.Fingerprint.(fingerprint) + copy(fp1[:], src1[:]) + copy(fp2[:], src2[:]) + return hashsync.RangeInfo{ + Fingerprint: fp1, + Count: info1.Count, + Start: a.wrapIterator(info1.Start), + End: a.wrapIterator(info1.End), + }, hashsync.RangeInfo{ + Fingerprint: fp2, + Count: info2.Count, + Start: a.wrapIterator(info2.Start), + End: a.wrapIterator(info2.End), + }, nil +} + // Has implements hashsync.ItemStore. func (a *ItemStoreAdapter) Has(k hashsync.Ordered) (bool, error) { h := k.(types.Hash32) @@ -224,3 +275,7 @@ func (ia *iteratorAdapter) Key() (hashsync.Ordered, error) { func (ia *iteratorAdapter) Next() error { return ia.it.Next() } + +func (ia *iteratorAdapter) Clone() hashsync.Iterator { + return &iteratorAdapter{it: ia.it.Clone()} +} diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index e834352d32..0d80ab696d 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -67,7 +67,7 @@ type dbRangeIterator struct { loaded bool } -var _ iterator = &dbRangeIterator{} +var _ hashsync.Iterator = &dbRangeIterator{} // makeDBIterator creates a dbRangeIterator and initializes it from the database. // If query returns no rows even after starting from zero ID, errEmptySet error is returned. 
@@ -76,7 +76,7 @@ func newDBRangeIterator( query string, from KeyBytes, maxChunkSize int, -) iterator { +) hashsync.Iterator { if from == nil { panic("BUG: makeDBIterator: nil from") } @@ -203,7 +203,7 @@ func (it *dbRangeIterator) Next() error { return it.load() } -func (it *dbRangeIterator) clone() iterator { +func (it *dbRangeIterator) Clone() hashsync.Iterator { cloned := *it cloned.from = slices.Clone(it.from) cloned.chunk = make([]KeyBytes, len(it.chunk)) @@ -215,15 +215,15 @@ func (it *dbRangeIterator) clone() iterator { type combinedIterator struct { startingPoint hashsync.Ordered - iters []iterator - wrapped []iterator - ahead iterator + iters []hashsync.Iterator + wrapped []hashsync.Iterator + ahead hashsync.Iterator aheadIdx int } // combineIterators combines multiple iterators into one, returning the smallest current // key among all iterators at each step. -func combineIterators(startingPoint hashsync.Ordered, iters ...iterator) iterator { +func combineIterators(startingPoint hashsync.Ordered, iters ...hashsync.Iterator) hashsync.Iterator { return &combinedIterator{startingPoint: startingPoint, iters: iters} } @@ -257,7 +257,7 @@ func (c *combinedIterator) begin() error { return nil } -func (c *combinedIterator) aheadIterator() (iterator, error) { +func (c *combinedIterator) aheadIterator() (hashsync.Iterator, error) { if err := c.begin(); err != nil { return nil, err } @@ -325,17 +325,17 @@ func (c *combinedIterator) Next() error { return nil } -func (c *combinedIterator) clone() iterator { +func (c *combinedIterator) Clone() hashsync.Iterator { cloned := &combinedIterator{ - iters: make([]iterator, len(c.iters)), - wrapped: make([]iterator, len(c.wrapped)), + iters: make([]hashsync.Iterator, len(c.iters)), + wrapped: make([]hashsync.Iterator, len(c.wrapped)), startingPoint: c.startingPoint, } for i, it := range c.iters { - cloned.iters[i] = it.clone() + cloned.iters[i] = it.Clone() } for i, it := range c.wrapped { - cloned.wrapped[i] = it.clone() + 
cloned.wrapped[i] = it.Clone() } return cloned } diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 3463c79281..bcc3529987 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -296,7 +296,7 @@ func TestDBRangeIterator(t *testing.T) { } // when there are no items, errEmptySet is returned require.NotEmpty(t, tc.items) - clonedIt := it.clone() + clonedIt := it.Clone() var collected []KeyBytes for i := 0; i < len(tc.items); i++ { k := itKey(t, it) @@ -312,7 +312,7 @@ func TestDBRangeIterator(t *testing.T) { expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) require.Equal(t, expected, collected, "count=%d from=%s maxChunkSize=%d", len(tc.items), hex.EncodeToString(tc.from), maxChunkSize) - clonedIt = it.clone() + clonedIt = it.Clone() for range 2 { for i := 0; i < len(tc.items); i++ { k := itKey(t, it) @@ -354,7 +354,7 @@ func (it *fakeIterator) Next() error { return nil } -func (it *fakeIterator) clone() iterator { +func (it *fakeIterator) Clone() hashsync.Iterator { cloned := &fakeIterator{ allItems: make([]KeyBytes, len(it.allItems)), } @@ -380,7 +380,7 @@ func TestCombineIterators(t *testing.T) { } it := combineIterators(nil, it1, it2) - clonedIt := it.clone() + clonedIt := it.Clone() for range 3 { var collected []KeyBytes for range 4 { diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 03bcf8ee2e..f65e90d43c 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "fmt" "io" "math/bits" @@ -16,6 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) +var errEasySplitFailed = errors.New("easy split failed") + type trace struct { traceEnabled bool traceStack []string @@ -30,7 +33,7 @@ func (t *trace) enter(format string, args ...any) { return } for n, arg := range args { - if it, ok := arg.(iterator); ok { + if it, ok := arg.(hashsync.Iterator); ok { args[n] = formatIter(it) } } @@ 
-47,7 +50,11 @@ func (t *trace) leave(results ...any) { panic("BUG: trace stack underflow") } for n, r := range results { - if it, ok := r.(iterator); ok { + if err, ok := r.(error); ok { + results = []any{fmt.Sprintf("", err)} + break + } + if it, ok := r.(hashsync.Iterator); ok { results[n] = formatIter(it) } } @@ -66,7 +73,7 @@ func (t *trace) leave(results ...any) { func (t *trace) log(format string, args ...any) { if t.traceEnabled { for n, arg := range args { - if it, ok := arg.(iterator); ok { + if it, ok := arg.(hashsync.Iterator); ok { args[n] = formatIter(it) } } @@ -242,6 +249,13 @@ func (p prefix) idAfter(b KeyBytes) { } s := uint64(64 - p.len()) v := (p.bits() + 1) << s + if v == 0 { + // wraparound + for n := range b { + b[n] = 0 + } + return + } binary.BigEndian.PutUint64(b, v) for n := 8; n < len(b); n++ { b[n] = 0xff @@ -302,32 +316,50 @@ type fpResult struct { fp fingerprint count uint32 itype int - start, end iterator + start, end hashsync.Iterator } type aggContext struct { - x, y KeyBytes - fp fingerprint - count uint32 - itype int - limit int - total uint32 - start, end iterator - lastPrefix *prefix -} + x, y KeyBytes + fp, fp0 fingerprint + count, count0 uint32 + itype int + limit int + total uint32 + start, end hashsync.Iterator + lastPrefix, lastPrefix0 *prefix + easySplit bool +} + +// QQQQQ: TBD: rm +// // pruneX returns true if any ID derived from the specified prefix is strictly below x. +// // With inverse intervals, it should only be used when processing [x, max) part of the +// // interval. +// func (ac *aggContext) pruneX(p prefix) bool { +// // QQQQQ: TBD: <= must work, check !!!!! +// return (p.bits()+1)<<(64-p.len())-1 < load64(ac.x) +// } // prefixAtOrAfterX verifies that the any key with the prefix p is at or after x. // It can be used for the whole interval in case of a normal interval. -// With inverse intervals, it should only be used for the [x, max) part -// of the interval. 
+// With inverse intervals, it should only be used when processing the [x, max) part of the +// interval. func (ac *aggContext) prefixAtOrAfterX(p prefix) bool { return p.bits()<<(64-p.len()) >= load64(ac.x) } +// QQQQQ: TBD: rm +// // pruneY returns true if any ID derived from the specified prefix is at or after y. +// // With inverse intervals, it should only be used when processing the [0, y) part of the +// // interval. +// func (ac *aggContext) pruneY(p prefix) bool { +// return p.bits()<<(64-p.len()) >= load64(ac.y) +// } + // prefixBelowY verifies that the any key with the prefix p is below y. // It can be used for the whole interval in case of a normal interval. -// With inverse intervals, it should only be used for the [0, y) part -// of the interval. +// With inverse intervals, it should only be used when processing the [0, y) part of the +// interval. func (ac *aggContext) prefixBelowY(p prefix) bool { // QQQQQ: TBD: <= must work, check !!!!! return (p.bits()+1)<<(64-p.len())-1 < load64(ac.y) @@ -360,13 +392,31 @@ func (ac *aggContext) nodeBelowY(node node, p prefix) bool { return ac.prefixBelowY(p) } +func (ac *aggContext) pruneY(node node) bool { + if node.c != 1 { + return false + } + k := make(KeyBytes, len(ac.y)) + copy(k, node.fp[:]) + return bytes.Compare(k, ac.y) >= 0 +} + func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { switch { case ac.limit < 0: - case uint32(ac.limit) < node.c: - return false - default: + case uint32(ac.limit) >= node.c: ac.limit -= int(node.c) + case ac.easySplit && node.leaf(): + ac.limit = -1 + ac.fp0 = ac.fp + ac.count0 = ac.count + ac.lastPrefix0 = ac.lastPrefix + copy(ac.fp[:], node.fp[:]) + ac.count = node.c + ac.lastPrefix = &p + return true + default: + return false } ac.fp.update(node.fp[:]) ac.count += node.c @@ -374,16 +424,11 @@ func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { return true } -type iterator interface { - hashsync.Iterator - clone() iterator -} - type idStore 
interface { clone() idStore registerHash(h KeyBytes) error - start() iterator - iter(from KeyBytes) iterator + start() hashsync.Iterator + iter(from KeyBytes) hashsync.Iterator } type fpTree struct { @@ -574,10 +619,15 @@ func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeInde // It returns a boolean indicating whether the limit or the right edge (y) was reached and // an error, if any. func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont bool, err error) { - ft.log("aggregateEdge: x %s y %s p %s limit %d count %d", x, y, p, ac.limit, ac.count) + ft.enter("aggregateEdge: x %s y %s p %s limit %d count %d", x, y, p, ac.limit, ac.count) defer func() { - ft.log("aggregateEdge ==> limit %d count %d\n", ac.limit, ac.count) + ft.leave(ac.limit, ac.count, cont, err) }() + if ac.easySplit { + // easySplit means we should not be querying the database, + // so we'll have to retry using slower strategy + return false, errEasySplitFailed + } if ac.limit == 0 && ac.end != nil { ft.log("aggregateEdge: limit is 0 and end already set") return false, nil @@ -592,7 +642,7 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b ft.log("aggregateEdge: startFrom %s", startFrom) it := ft.iter(startFrom) if ac.limit == 0 { - ac.end = it.clone() + ac.end = it.Clone() if x != nil { ft.log("aggregateEdge: limit 0: x is not nil, setting start to %s", ac.start) ac.start = ac.end @@ -601,7 +651,7 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b return false, nil } if x != nil { - ac.start = it.clone() + ac.start = it.Clone() ft.log("aggregateEdge: x is not nil, setting start to %s", ac.start) } @@ -647,82 +697,63 @@ func (ft *fpTree) node(idx nodeIndex) (node, bool) { return ft.np.node(idx), true } -// QQQQQ: rm -// func (ft *fpTree) markEnd(p prefix, ac *aggContext) error { -// if ac.end != nil { -// return nil -// } -// k := make(KeyBytes, ft.keyLen) -// p.minID(k) -// it, err := 
ft.iter(k) -// if err != nil { -// return err -// } -// ac.end = it -// ft.log("markEnd: p %s k %s => %s", p, k, it.Key().(fmt.Stringer)) -// return nil -// } - func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, ac.limit, ac.fp, ac.count) defer func() { - ft.leave(ac.fp, ac.count) + ft.leave(ac.fp, ac.count, err) }() - for { - node, ok := ft.node(idx) - switch { - case !ok: - ft.log("stop: no node") - return true, nil - // case ac.limit == 0: - // // for ac.limit == 0, it's important that we still visit the node - // // so that we can get the item immediately following the included items - // ft.log("stop: limit exhausted") - // return false, ft.markEnd(p, ac) - case ac.maybeIncludeNode(node, p): - // node is fully included - ft.log("included fully") - return true, nil - case node.leaf(): - // reached the limit on this node, do not need to continue after - // done with it - cont, err := ft.aggregateEdge(nil, nil, p, ac) - if err != nil { - return false, err - } - if cont { - panic("BUG: expected limit not reached") - } - return false, nil - default: - pLeft := p.left() - left, haveLeft := ft.node(node.left) - if haveLeft { - if ac.maybeIncludeNode(left, pLeft) { - // left node is fully included, after which - // we need to stop somewhere in the right subtree - ft.log("include left in full") - } else { - // we must stop somewhere in the left subtree, - // and the right subtree is irrelevant - ft.log("descend to the left") - idx = node.left - p = pLeft - continue + node, ok := ft.node(idx) + switch { + case !ok: + ft.log("stop: no node") + return true, nil + case ac.maybeIncludeNode(node, p): + // node is fully included + ft.log("included fully, lastPrefix = %s", ac.lastPrefix) + return true, nil + case node.leaf(): + // reached the limit on this node, do not need to continue after + // done with it + cont, err := 
ft.aggregateEdge(nil, nil, p, ac) + if err != nil { + return false, err + } + if cont { + panic("BUG: expected limit not reached") + } + return false, nil + default: + pLeft := p.left() + left, haveLeft := ft.node(node.left) + if haveLeft { + if ac.maybeIncludeNode(left, pLeft) { + // left node is fully included, after which + // we need to stop somewhere in the right subtree + ft.log("include left in full") + } else { + // we must stop somewhere in the left subtree, + // and the right subtree is irrelevant unless + // easySplit is being done and we must restart + // after the limit is exhausted + ft.log("descend to the left") + if cont, err := ft.aggregateUpToLimit(node.left, pLeft, ac); !cont || err != nil { + return cont, err + } + if !ac.easySplit { + return } } - ft.log("descend to the right") - idx = node.right - p = p.right() } + ft.log("descend to the right") + return ft.aggregateUpToLimit(node.right, p.right(), ac) } } func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateLeft: idx %d v %016x p %s limit %d", idx, v, p, ac.limit) defer func() { - ft.leave(ac.fp, ac.count) + ft.leave(ac.fp, ac.count, err) }() node, ok := ft.node(idx) switch { @@ -730,11 +761,12 @@ func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContex // for ac.limit == 0, it's important that we still visit the node // so that we can get the item immediately following the included items ft.log("stop: no node") - // QQQQQ: no mark end.... 
- return true, nil //ft.markEnd(p, ac) - // case ac.limit == 0: - // ft.log("stop: limit exhausted") - // return false, ft.markEnd(p, ac) + return true, nil + // QQQQQ: TBD: rm + // case ac.pruneX(p): + // // TODO: rm, this never happens + // ft.log("prune: prefix not at or after x") + // return true, nil case ac.nodeAtOrAfterX(node, p) && ac.maybeIncludeNode(node, p): ft.log("including node in full: %s limit %d", p, ac.limit) return ac.limit != 0, nil @@ -765,16 +797,13 @@ func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContex func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateRight: idx %d v %016x p %s limit %d", idx, v, p, ac.limit) defer func() { - ft.leave(ac.fp, ac.count) + ft.leave(ac.fp, ac.count, err) }() node, ok := ft.node(idx) switch { case !ok: ft.log("stop: no node") return true, nil - // case ac.limit == 0: - // ft.log("stop: limit exhausted") - // return false, ft.markEnd(p, ac) case ac.nodeBelowY(node, p) && ac.maybeIncludeNode(node, p): ft.log("including node in full: %s limit %d", p, ac.limit) return ac.limit != 0, nil @@ -782,6 +811,10 @@ func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggConte if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } + if ac.pruneY(node) { + ft.log("node %d p %s pruned", idx, p) + return false, nil + } return ft.aggregateEdge(nil, ac.y, p, ac) case v&bit63 == 0: ft.log("go left to node %d", node.left) @@ -798,10 +831,14 @@ func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggConte } } -func (ft *fpTree) aggregateXX(ac *aggContext) error { +func (ft *fpTree) aggregateXX(ac *aggContext) (err error) { // [x; x) interval which denotes the whole set unless // the limit is specified, in which case we need to start aggregating // with x and wrap around if necessary + ft.enter("aggregateXX: x %s limit %d", ac.x, ac.limit) + defer 
func() { + ft.leave(ac, err) + }() if ft.root == noIndex { ft.log("empty set (no root)") } else if ac.maybeIncludeNode(ft.np.node(ft.root), 0) { @@ -814,8 +851,12 @@ func (ft *fpTree) aggregateXX(ac *aggContext) error { return nil } -func (ft *fpTree) aggregateSimple(ac *aggContext) error { +func (ft *fpTree) aggregateSimple(ac *aggContext) (err error) { // "proper" interval: [x; lca); (lca; y) + ft.enter("aggregateSimple: x %s y %s limit %d", ac.x, ac.y, ac.limit) + defer func() { + ft.leave(ac, err) + }() p := commonPrefix(ac.x, ac.y) lcaIdx, lcaPrefix, fullPrefixFound := ft.followPrefix(ft.root, p, 0) var lca node @@ -829,9 +870,13 @@ func (ft *fpTree) aggregateSimple(ac *aggContext) error { if lcaPrefix != p { panic("BUG: bad followedPrefix") } - ft.aggregateLeft(lca.left, load64(ac.x)<<(p.len()+1), p.left(), ac) + if _, err := ft.aggregateLeft(lca.left, load64(ac.x)<<(p.len()+1), p.left(), ac); err != nil { + return err + } if ac.limit != 0 { - ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac) + if _, err := ft.aggregateRight(lca.right, load64(ac.y)<<(p.len()+1), p.right(), ac); err != nil { + return err + } } case lcaIdx == noIndex || !lca.leaf(): ft.log("commonPrefix %s NOT found b/c no items have it", p) @@ -847,9 +892,15 @@ func (ft *fpTree) aggregateSimple(ac *aggContext) error { return nil } -func (ft *fpTree) aggregateInverse(ac *aggContext) error { +func (ft *fpTree) aggregateInverse(ac *aggContext) (err error) { // inverse interval: [min; y); [x; max] - // first, we handle [x; max] part + + // First, we handle [x; max] part + // For this, we process the subtree rooted in the LCA of 0x000000... 
(all 0s) and x + ft.enter("aggregateInverse: x %s y %s limit %d", ac.x, ac.y, ac.limit) + defer func() { + ft.leave(ac, err) + }() pf0 := preFirst0(ac.x) idx0, followedPrefix, found := ft.followPrefix(ft.root, pf0, 0) var pf0Node node @@ -862,7 +913,13 @@ func (ft *fpTree) aggregateInverse(ac *aggContext) error { if followedPrefix != pf0 { panic("BUG: bad followedPrefix") } - ft.aggregateLeft(idx0, load64(ac.x)< %s", *ac.lastPrefix, fpr.end) } else { fpr.end = ft.iter(y) ft.log("fingerprintInterval: end at y: %s", fpr.end) @@ -980,6 +1044,59 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e return fpr, nil } +// easySplit splits an interval in two parts trying to do it in such way that the first +// part has close to limit items while not making any idStore queries so that the database +// is not accessed. If the split can't be done, which includes the situation where one of +// the sides has 0 items, easySplit returns errEasySplitFailed error +func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (fprA, fprB fpResult, err error) { + ft.enter("easySplit: x %s y %s limit %d", x, y, limit) + defer func() { + ft.leave(fprA.fp, fprA.count, fprA.itype, fprA.start, fprA.end, + fprB.fp, fprB.count, fprB.itype, fprB.start, fprB.end, err) + }() + if limit < 0 { + panic("BUG: easySplit with limit < 0") + } + ac := aggContext{x: x, y: y, limit: limit, easySplit: true} + if err := ft.aggregateInterval(&ac); err != nil { + return fpResult{}, fpResult{}, err + } + + if ac.total == 0 { + return fpResult{}, fpResult{}, nil + } + + if ac.count0 == 0 || ac.count == 0 { + // need to get some items on both sides for the easy split to succeed + return fpResult{}, fpResult{}, errEasySplitFailed + } + + // It should not be possible to have ac.lastPrefix0 == nil or ac.lastPrefix == nil + // if both ac.count0 and ac.count are non-zero, b/c of how + // aggContext.maybeIncludeNode works + if ac.lastPrefix0 == nil || ac.lastPrefix == nil { + panic("BUG: 
easySplit lastPrefix or lastPrefix0 not set") + } + + // ac.start / ac.end are only set in aggregateEdge which fails with + // errEasySplitFailed if easySplit is enabled, so we can ignore them here + fprA = fpResult{ + fp: ac.fp0, + count: ac.count0, + itype: ac.itype, + start: ft.iter(x), + end: ft.endIterFromPrefix(*ac.lastPrefix0), + } + fprB = fpResult{ + fp: ac.fp, + count: ac.count, + itype: ac.itype, + start: fprA.end.Clone(), + end: ft.endIterFromPrefix(*ac.lastPrefix), + } + return fprA, fprB, nil +} + func (ft *fpTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) { if idx == noIndex { return @@ -1015,7 +1132,7 @@ func (ft *fpTree) count() int { } type iterFormatter struct { - it iterator + it hashsync.Iterator } func (f iterFormatter) String() string { @@ -1026,7 +1143,7 @@ func (f iterFormatter) String() string { } } -func formatIter(it iterator) fmt.Stringer { +func formatIter(it hashsync.Iterator) fmt.Stringer { return iterFormatter{it: it} } diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index af47ba5228..69d0c29056 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -1,6 +1,7 @@ package dbsync import ( + "encoding/binary" "fmt" "math" "math/rand" @@ -785,12 +786,12 @@ func (noIDStore) registerHash(h KeyBytes) error { return nil } -func (noIDStore) start() iterator { +func (noIDStore) start() hashsync.Iterator { panic("no ID store") } -func (noIDStore) iter(from KeyBytes) iterator { +func (noIDStore) iter(from KeyBytes) hashsync.Iterator { return noIter{} } @@ -804,11 +805,11 @@ func (noIter) Next() error { panic("no ID store") } -func (noIter) clone() iterator { +func (noIter) Clone() hashsync.Iterator { return noIter{} } -var _ iterator = &noIter{} +var _ hashsync.Iterator = &noIter{} // TestFPTreeNoIDStore tests that an fpTree can avoid using an idStore if X has only // 0 bits below max-depth and Y has only 1 bits below max-depth. 
It also checks that an fpTree @@ -1211,6 +1212,79 @@ func TestFPTreeManyItems(t *testing.T) { // TBD: test start/end iterators } +func verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes) { + t.Logf("--- fingerprint interval %s %s ---", x.String(), y.String()) + fpr, err := ft.fingerprintInterval(x, y, -1) + require.NoError(t, err) + if fpr.count <= 1 { + return + } + a, err := fpr.start.Key() + require.NoError(t, err) + b, err := fpr.end.Key() + require.NoError(t, err) + + m := fpr.count / 2 + t.Logf("--- easy split %s %s %d ---", x.String(), y.String(), m) + fpr1, fpr2, err := ft.easySplit(x[:], y[:], int(m)) + require.NoError(t, err) + require.NotZero(t, fpr1.count) + require.NotZero(t, fpr2.count) + require.Equal(t, fpr.count, fpr1.count+fpr2.count) + require.Equal(t, fpr.itype, fpr1.itype) + require.Equal(t, fpr.itype, fpr2.itype) + fp := fpr1.fp + fp.update(fpr2.fp[:]) + require.Equal(t, fpr.fp, fp) + require.Equal(t, a, itKey(t, fpr1.start)) + require.Equal(t, b, itKey(t, fpr2.end)) + middle := itKey(t, fpr1.end) + require.Equal(t, middle, itKey(t, fpr2.start)) + + fpr11, err := ft.fingerprintInterval(x, middle, -1) + require.NoError(t, err) + require.Equal(t, fpr1.fp, fpr11.fp) + require.Equal(t, fpr1.count, fpr11.count) + require.Equal(t, a, itKey(t, fpr11.start)) + require.Equal(t, middle, itKey(t, fpr11.end)) + + fpr12, err := ft.fingerprintInterval(middle, y, -1) + require.NoError(t, err) + require.Equal(t, fpr2.fp, fpr12.fp) + require.Equal(t, fpr2.count, fpr12.count) + require.Equal(t, middle, itKey(t, fpr12.start)) + require.Equal(t, b, itKey(t, fpr12.end)) + + // TBD: QQQQQ: recurse! 
+} + +func TestEasySplit(t *testing.T) { + var np nodePool + maxDepth := 24 + ft := newFPTree(&np, newInMemIDStore(32), 32, maxDepth) + for range 10 { + h := types.RandomHash() + t.Logf("adding hash %s", h.String()) + ft.addHash(h[:]) + } + k, err := ft.start().Key() + require.NoError(t, err) + x := k.(KeyBytes) + v := load64(x) & ^(1<<(64-maxDepth) - 1) + binary.BigEndian.PutUint64(x, v) + for i := 8; i < len(x); i++ { + x[i] = 0 + } + + ft.traceEnabled = true + var sb strings.Builder + ft.dump(&sb) + t.Logf("tree:\n%s", sb.String()) + + verifyEasySplit(t, ft, x, x) + // TBD: test split with leafs that have c > 1 +} + const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" // func dumbAggATXs(t *testing.T, db sql.StateDatabase, x, y types.Hash32) fpResult { diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index c3619f21bb..5694cfab19 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -34,11 +34,11 @@ func (s *inMemIDStore) registerHash(h KeyBytes) error { return nil } -func (s *inMemIDStore) start() iterator { +func (s *inMemIDStore) start() hashsync.Iterator { return &inMemIDStoreIterator{sl: s.sl, node: s.sl.First()} } -func (s *inMemIDStore) iter(from KeyBytes) iterator { +func (s *inMemIDStore) iter(from KeyBytes) hashsync.Iterator { node := s.sl.FindGTENode(from) if node == nil { node = s.sl.First() @@ -51,7 +51,7 @@ type inMemIDStoreIterator struct { node *skiplist.Node } -var _ iterator = &inMemIDStoreIterator{} +var _ hashsync.Iterator = &inMemIDStoreIterator{} func (it *inMemIDStoreIterator) Key() (hashsync.Ordered, error) { if it.node == nil { @@ -70,7 +70,7 @@ func (it *inMemIDStoreIterator) Next() error { return nil } -func (it *inMemIDStoreIterator) clone() iterator { +func (it *inMemIDStoreIterator) Clone() hashsync.Iterator { cloned := *it return &cloned } diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go index 
623e881757..0a74582d3b 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -4,14 +4,16 @@ import ( "encoding/hex" "testing" + "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/common/util" - "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) func TestInMemIDStore(t *testing.T) { var ( - it iterator + it hashsync.Iterator err error ) s := newInMemIDStore(32) diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 2569953c43..616ba81ecf 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -4,6 +4,7 @@ import ( "bytes" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) const sqlMaxChunkSize = 1024 @@ -29,12 +30,12 @@ func (s *sqlIDStore) registerHash(h KeyBytes) error { return nil } -func (s *sqlIDStore) start() iterator { +func (s *sqlIDStore) start() hashsync.Iterator { // TODO: should probably use a different query to get the first key return s.iter(make(KeyBytes, s.keyLen)) } -func (s *sqlIDStore) iter(from KeyBytes) iterator { +func (s *sqlIDStore) iter(from KeyBytes) hashsync.Iterator { if len(from) != s.keyLen { panic("BUG: invalid key length") } @@ -66,13 +67,13 @@ func (s *dbBackedStore) registerHash(h KeyBytes) error { return s.inMemIDStore.registerHash(h) } -func (s *dbBackedStore) start() iterator { +func (s *dbBackedStore) start() hashsync.Iterator { dbIt := s.sqlIDStore.start() memIt := s.inMemIDStore.start() return combineIterators(nil, dbIt, memIt) } -func (s *dbBackedStore) iter(from KeyBytes) iterator { +func (s *dbBackedStore) iter(from KeyBytes) hashsync.Iterator { dbIt := s.sqlIDStore.iter(from) memIt := s.inMemIDStore.iter(from) return combineIterators(from, dbIt, memIt) diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index 58b37361d7..a83d1330f6 100644 --- 
a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -132,6 +132,10 @@ func (it *sliceIterator) Next() error { return nil } +func (it *sliceIterator) Clone() Iterator { + return &sliceIterator{s: it.s} +} + type fakeSend struct { x, y Ordered count int @@ -434,6 +438,7 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) Requester { } var client Requester verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { + opts = append(opts, WithRangeSyncLogger(zaptest.NewLogger(t))) // QQQQQ: TBD: rm withClientServer( storeA, getRequester, opts, func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 32bfc1f46d..3379b75976 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -20,6 +20,8 @@ type Iterator interface { Key() (Ordered, error) // Next advances the iterator Next() error + // Clone returns a copy of the iterator + Clone() Iterator } type RangeInfo struct { @@ -38,6 +40,9 @@ type ItemStore interface { // If both x and y is nil, the whole set of items is used. // If only x or only y is nil, GetRangeInfo panics GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) + // SplitRange splits the range roughly after the specified count of items, + // returning RangeInfo for the first half and the second half of the range. + SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) // Min returns the iterator pointing at the minimum element // in the store. If the store is empty, it returns nil Min() (Iterator, error) diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go index adb27568f4..93adca1ec9 100644 --- a/sync2/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -44,11 +44,12 @@ func (m *MockIterator) EXPECT() *MockIteratorMockRecorder { } // Key mocks base method. 
-func (m *MockIterator) Key() Ordered { +func (m *MockIterator) Key() (Ordered, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Key") ret0, _ := ret[0].(Ordered) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Key indicates an expected call of Key. @@ -64,19 +65,19 @@ type MockIteratorKeyCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockIteratorKeyCall) Return(arg0 Ordered) *MockIteratorKeyCall { - c.Call = c.Call.Return(arg0) +func (c *MockIteratorKeyCall) Return(arg0 Ordered, arg1 error) *MockIteratorKeyCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockIteratorKeyCall) Do(f func() Ordered) *MockIteratorKeyCall { +func (c *MockIteratorKeyCall) Do(f func() (Ordered, error)) *MockIteratorKeyCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockIteratorKeyCall) DoAndReturn(f func() Ordered) *MockIteratorKeyCall { +func (c *MockIteratorKeyCall) DoAndReturn(f func() (Ordered, error)) *MockIteratorKeyCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -335,6 +336,46 @@ func (c *MockItemStoreMinCall) DoAndReturn(f func() (Iterator, error)) *MockItem return c } +// SplitRange mocks base method. +func (m *MockItemStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SplitRange", preceding, x, y, count) + ret0, _ := ret[0].(RangeInfo) + ret1, _ := ret[1].(RangeInfo) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// SplitRange indicates an expected call of SplitRange. 
+func (mr *MockItemStoreMockRecorder) SplitRange(preceding, x, y, count any) *MockItemStoreSplitRangeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SplitRange", reflect.TypeOf((*MockItemStore)(nil).SplitRange), preceding, x, y, count) + return &MockItemStoreSplitRangeCall{Call: call} +} + +// MockItemStoreSplitRangeCall wrap *gomock.Call +type MockItemStoreSplitRangeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreSplitRangeCall) Return(arg0, arg1 RangeInfo, arg2 error) *MockItemStoreSplitRangeCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreSplitRangeCall) Do(f func(Iterator, Ordered, Ordered, int) (RangeInfo, RangeInfo, error)) *MockItemStoreSplitRangeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreSplitRangeCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) (RangeInfo, RangeInfo, error)) *MockItemStoreSplitRangeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // MockRequester is a mock of Requester interface. 
type MockRequester struct { ctrl *gomock.Controller diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 392310f858..53fd36d184 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -208,24 +208,12 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang // return fmt.Sprintf("%s", it.Key()) // } -func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, y Ordered) (Iterator, error) { - if preceding != nil { - k, err := preceding.Key() - if err != nil { - return nil, err - } - if k.Compare(x) > 0 { - preceding = nil - } - } +func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info RangeInfo) error { // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", // qqqqRmmeK(preceding)) // TODO: don't re-request range info for the first part of range after stop - rsr.log.Debug("processSubrange", IteratorField("preceding", preceding), HexField("x", x), HexField("y", y)) - info, err := rsr.is.GetRangeInfo(preceding, x, y, -1) - if err != nil { - return nil, err - } + rsr.log.Debug("processSubrange", HexField("x", x), HexField("y", y), + zap.Int("count", info.Count), HexField("fingerprint", info.Fingerprint)) // fmt.Fprintf(os.Stderr, "QQQQQ: start=%q end=%q info.Start=%q info.End=%q info.FP=%q x=%q y=%q\n", // qqqqRmmeK(start), qqqqRmmeK(end), qqqqRmmeK(info.Start), qqqqRmmeK(info.End), info.Fingerprint, x, y) switch { @@ -247,7 +235,7 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, // Ask peer to send any items it has in the range rsr.log.Debug("processSubrange: send empty range", HexField("x", x), HexField("y", y)) if err := c.SendEmptyRange(x, y); err != nil { - return nil, err + return err } default: // The range is non-empty and large enough. 
@@ -255,11 +243,11 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, preceding Iterator, x, rsr.log.Debug("processSubrange: send fingerprint", HexField("x", x), HexField("y", y), zap.Int("count", info.Count)) if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { - return nil, err + return err } } // fmt.Fprintf(os.Stderr, "QQQQQ: info.End=%q\n", qqqqRmmeK(info.End)) - return info.End, nil + return nil } func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) (it Iterator, done bool, err error) { @@ -389,31 +377,30 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg rsr.log.Debug("handleMessage: PRE split range", HexField("x", x), HexField("y", y), zap.Int("countArg", count)) - part, err := rsr.is.GetRangeInfo(preceding, x, y, count) + part0, part1, err := rsr.is.SplitRange(preceding, x, y, count) if err != nil { return nil, false, err } - if part.End == nil { - panic("BUG: can't split range with count > 1") - } rsr.log.Debug("handleMessage: split range", HexField("x", x), HexField("y", y), zap.Int("countArg", count), - zap.Int("count", part.Count), - HexField("fingerprint", part.Fingerprint), - IteratorField("start", part.Start), - IteratorField("middle", part.End)) - middle, err := part.End.Key() + zap.Int("count0", part0.Count), + HexField("fp0", part0.Fingerprint), + IteratorField("start0", part0.Start), + IteratorField("end0", part0.End), + zap.Int("count1", part1.Count), + HexField("fp1", part1.Fingerprint), + IteratorField("start1", part1.End), + IteratorField("end1", part1.End)) + middle, err := part0.End.Key() if err != nil { return nil, false, err } - next, err := rsr.processSubrange(c, info.Start, x, middle) - if err != nil { + if err := rsr.processSubrange(c, x, middle, part0); err != nil { return nil, false, err } // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) - _, err = rsr.processSubrange(c, next, middle, y) - if err != nil { + if err 
:= rsr.processSubrange(c, middle, y, part1); err != nil { return nil, false, err } // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index 4fac1a6845..41723a7768 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -2,6 +2,7 @@ package hashsync import ( "context" + "errors" "fmt" "math/rand" "slices" @@ -208,6 +209,13 @@ func (it *dumbStoreIterator) Next() error { return nil } +func (it *dumbStoreIterator) Clone() Iterator { + return &dumbStoreIterator{ + ds: it.ds, + n: it.n, + } +} + type dumbStore struct { keys []sampleID } @@ -307,6 +315,28 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) ( return r, nil } +func (ds *dumbStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { + if count <= 0 { + panic("BUG: bad split count") + } + part0, err := ds.GetRangeInfo(preceding, x, y, count) + if err != nil { + return RangeInfo{}, RangeInfo{}, err + } + if part0.Count == 0 { + return RangeInfo{}, RangeInfo{}, errors.New("can't split empty range") + } + middle, err := part0.End.Key() + if err != nil { + return RangeInfo{}, RangeInfo{}, err + } + part1, err := ds.GetRangeInfo(part0.End.Clone(), middle, y, -1) + if err != nil { + return RangeInfo{}, RangeInfo{}, err + } + return part0, part1, nil +} + func (ds *dumbStore) Min() (Iterator, error) { if len(ds.keys) == 0 { return nil, nil @@ -372,6 +402,14 @@ func (it verifiedStoreIterator) Next() error { return nil } +func (it verifiedStoreIterator) Clone() Iterator { + return verifiedStoreIterator{ + t: it.t, + knownGood: it.knownGood.Clone(), + it: it.it.Clone(), + } +} + type verifiedStore struct { t *testing.T knownGood ItemStore @@ -406,23 +444,7 @@ func (vs *verifiedStore) Add(ctx context.Context, k Ordered) error { return nil } -func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) 
(RangeInfo, error) { - var ( - ri1, ri2 RangeInfo - err error - ) - if preceding != nil { - p := preceding.(verifiedStoreIterator) - ri1, err = vs.knownGood.GetRangeInfo(p.knownGood, x, y, count) - require.NoError(vs.t, err) - ri2, err = vs.store.GetRangeInfo(p.it, x, y, count) - require.NoError(vs.t, err) - } else { - ri1, err = vs.knownGood.GetRangeInfo(nil, x, y, count) - require.NoError(vs.t, err) - ri2, err = vs.store.GetRangeInfo(nil, x, y, count) - require.NoError(vs.t, err) - } +func (vs *verifiedStore) verifySameRangeInfo(ri1, ri2 RangeInfo) RangeInfo { require.Equal(vs.t, ri1.Fingerprint, ri2.Fingerprint, "range info fingerprint") require.Equal(vs.t, ri1.Count, ri2.Count, "range info count") ri := RangeInfo{ @@ -463,9 +485,49 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in it: ri2.End, } } + return ri +} + +func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { + var ( + ri1, ri2 RangeInfo + err error + ) + if preceding != nil { + p := preceding.(verifiedStoreIterator) + ri1, err = vs.knownGood.GetRangeInfo(p.knownGood, x, y, count) + require.NoError(vs.t, err) + ri2, err = vs.store.GetRangeInfo(p.it, x, y, count) + require.NoError(vs.t, err) + } else { + ri1, err = vs.knownGood.GetRangeInfo(nil, x, y, count) + require.NoError(vs.t, err) + ri2, err = vs.store.GetRangeInfo(nil, x, y, count) + require.NoError(vs.t, err) + } // QQQQQ: TODO: if count >= 0 and start+end != nil, do more calls to GetRangeInfo using resulting // end iterator key to make sure the range is correct - return ri, nil + return vs.verifySameRangeInfo(ri1, ri2), nil +} + +func (vs *verifiedStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { + var ( + ri11, ri12, ri21, ri22 RangeInfo + err error + ) + if preceding != nil { + p := preceding.(verifiedStoreIterator) + ri11, ri12, err = vs.knownGood.SplitRange(p.knownGood, x, y, count) + require.NoError(vs.t, err) + 
ri21, ri22, err = vs.store.SplitRange(p.it, x, y, count) + require.NoError(vs.t, err) + } else { + ri11, ri12, err = vs.knownGood.SplitRange(nil, x, y, count) + require.NoError(vs.t, err) + ri21, ri22, err = vs.store.SplitRange(nil, x, y, count) + require.NoError(vs.t, err) + } + return vs.verifySameRangeInfo(ri11, ri21), vs.verifySameRangeInfo(ri12, ri22), nil } func (vs *verifiedStore) Min() (Iterator, error) { diff --git a/sync2/hashsync/sync_tree.go b/sync2/hashsync/sync_tree.go index 6f09a838cb..0ab32c6b8e 100644 --- a/sync2/hashsync/sync_tree.go +++ b/sync2/hashsync/sync_tree.go @@ -74,6 +74,7 @@ type SyncTreePointer interface { Value() any Prev() Next() + Clone() SyncTreePointer } type flags uint8 @@ -218,6 +219,13 @@ func (p *syncTreePointer) Next() { } } +func (p *syncTreePointer) Clone() SyncTreePointer { + return &syncTreePointer{ + parentStack: slices.Clone(p.parentStack), + node: p.node, + } +} + func (p *syncTreePointer) Key() Ordered { if p.node == nil { return nil diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index 271c411c51..c6cf658bb9 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -1,6 +1,9 @@ package hashsync -import "context" +import ( + "context" + "errors" +) type syncTreeIterator struct { st SyncTree @@ -29,6 +32,13 @@ func (it *syncTreeIterator) Next() error { return nil } +func (it *syncTreeIterator) Clone() Iterator { + return &syncTreeIterator{ + st: it.st, + ptr: it.ptr.Clone(), + } +} + type SyncTreeStore struct { st SyncTree identity any @@ -104,6 +114,29 @@ func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count i }, nil } +// SplitRange implements ItemStore. 
+func (sts *SyncTreeStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { + if count <= 0 { + panic("BUG: bad split count") + } + part0, err := sts.GetRangeInfo(preceding, x, y, count) + if err != nil { + return RangeInfo{}, RangeInfo{}, err + } + if part0.Count == 0 { + return RangeInfo{}, RangeInfo{}, errors.New("can't split empty range") + } + middle, err := part0.End.Key() + if err != nil { + return RangeInfo{}, RangeInfo{}, err + } + part1, err := sts.GetRangeInfo(part0.End.Clone(), middle, y, -1) + if err != nil { + return RangeInfo{}, RangeInfo{}, err + } + return part0, part1, nil +} + // Min implements ItemStore. func (sts *SyncTreeStore) Min() (Iterator, error) { return sts.iter(sts.st.Min()), nil From 2513a207f32e22d47932f948faddd0cdc003c287 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 1 Aug 2024 13:41:35 +0400 Subject: [PATCH 56/76] wip4 --- sync2/dbsync/fptree.go | 87 +++++++++++++++++++++-------- sync2/dbsync/fptree_test.go | 106 +++++++++++++++++++++++++++--------- 2 files changed, 144 insertions(+), 49 deletions(-) diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index f65e90d43c..cd6143b9e3 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -258,7 +258,7 @@ func (p prefix) idAfter(b KeyBytes) { } binary.BigEndian.PutUint64(b, v) for n := 8; n < len(b); n++ { - b[n] = 0xff + b[n] = 0 // QQQQQ: was 0xff } } @@ -392,8 +392,14 @@ func (ac *aggContext) nodeBelowY(node node, p prefix) bool { return ac.prefixBelowY(p) } -func (ac *aggContext) pruneY(node node) bool { +func (ac *aggContext) pruneY(node node, p prefix) bool { + if p.bits()<<(64-p.len()) >= load64(ac.y) { + // min ID derived from the prefix is at or after y => prune + return true + } if node.c != 1 { + // node has count > 1, so we can't use its fingerpeint + // to determine if it's below y return false } k := make(KeyBytes, len(ac.y)) @@ -402,11 +408,23 @@ func (ac *aggContext) pruneY(node node) bool { } func 
(ac *aggContext) maybeIncludeNode(node node, p prefix) bool { + fmt.Fprintf(os.Stderr, "QQQQQ: maybeIncludeNode: limit %d node.c %d easySplit %v leaf %v\n", + ac.limit, node.c, ac.easySplit, node.leaf()) switch { case ac.limit < 0: case uint32(ac.limit) >= node.c: ac.limit -= int(node.c) - case ac.easySplit && node.leaf(): + case !ac.easySplit || !node.leaf(): + return false + case ac.count == 0: + // We're doing a split and this node is over the limit, but the first part + // is still empty so we include this node in the first part and + // then switch to the second part + ac.limit = 0 + default: + // We're doing a split and this node is over the limit, so store count and + // fingerpint for the first part and include the current node in the + // second part ac.limit = -1 ac.fp0 = ac.fp ac.count0 = ac.count @@ -415,12 +433,25 @@ func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { ac.count = node.c ac.lastPrefix = &p return true - default: - return false } ac.fp.update(node.fp[:]) ac.count += node.c ac.lastPrefix = &p + if ac.easySplit && ac.limit == 0 { + // We're doing a split and this node is exactly at the limit, or it was + // above the limit but first part was still empty, so store count and + // fingerprint for the first part which includes the current node and zero + // out cound and figerprint for the second part + ac.limit = -1 + ac.fp0 = ac.fp + ac.count0 = ac.count + ac.lastPrefix0 = ac.lastPrefix + for n := range ac.fp { + ac.fp[n] = 0 + } + ac.count = 0 + ac.lastPrefix = nil + } return true } @@ -698,10 +729,10 @@ func (ft *fpTree) node(idx nodeIndex) (node, bool) { } func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (cont bool, err error) { - ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count %d", idx, p, ac.limit, - ac.fp, ac.count) + ft.enter("aggregateUpToLimit: idx %d p %s limit %d cur_fp %s cur_count0 %d cur_count %d", idx, p, ac.limit, + ac.fp, ac.count0, ac.count) defer func() { - 
ft.leave(ac.fp, ac.count, err) + ft.leave(ac.fp, ac.count0, ac.count, err) }() node, ok := ft.node(idx) switch { @@ -753,7 +784,7 @@ func (ft *fpTree) aggregateUpToLimit(idx nodeIndex, p prefix, ac *aggContext) (c func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateLeft: idx %d v %016x p %s limit %d", idx, v, p, ac.limit) defer func() { - ft.leave(ac.fp, ac.count, err) + ft.leave(ac.fp, ac.count0, ac.count, err) }() node, ok := ft.node(idx) switch { @@ -797,7 +828,7 @@ func (ft *fpTree) aggregateLeft(idx nodeIndex, v uint64, p prefix, ac *aggContex func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateRight: idx %d v %016x p %s limit %d", idx, v, p, ac.limit) defer func() { - ft.leave(ac.fp, ac.count, err) + ft.leave(ac.fp, ac.count0, ac.count, err) }() node, ok := ft.node(idx) switch { @@ -811,7 +842,7 @@ func (ft *fpTree) aggregateRight(idx nodeIndex, v uint64, p prefix, ac *aggConte if node.left != noIndex || node.right != noIndex { panic("BUG: node @ maxDepth has children") } - if ac.pruneY(node) { + if ac.pruneY(node, p) { ft.log("node %d p %s pruned", idx, p) return false, nil } @@ -956,7 +987,7 @@ func (ft *fpTree) aggregateInverse(ac *aggContext) (err error) { // nothing to do case ac.nodeBelowY(pf1Node, followedPrefix) && ac.maybeIncludeNode(pf1Node, followedPrefix): // node is fully included - case ac.pruneY(pf1Node): + case ac.pruneY(pf1Node, followedPrefix): ft.log("node %d p %s pruned", idx1, followedPrefix) return nil default: @@ -1044,31 +1075,37 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e return fpr, nil } +type splitResult struct { + part0, part1 fpResult + middle KeyBytes +} + // easySplit splits an interval in two parts trying to do it in such way that the first // part has close to limit items while not making any idStore queries so that the database // is not 
accessed. If the split can't be done, which includes the situation where one of // the sides has 0 items, easySplit returns errEasySplitFailed error -func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (fprA, fprB fpResult, err error) { +func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (sr splitResult, err error) { ft.enter("easySplit: x %s y %s limit %d", x, y, limit) defer func() { - ft.leave(fprA.fp, fprA.count, fprA.itype, fprA.start, fprA.end, - fprB.fp, fprB.count, fprB.itype, fprB.start, fprB.end, err) + ft.leave(sr.part0.fp, sr.part0.count, sr.part0.itype, sr.part0.start, sr.part0.end, + sr.part1.fp, sr.part1.count, sr.part1.itype, sr.part1.start, sr.part1.end, err) }() if limit < 0 { panic("BUG: easySplit with limit < 0") } ac := aggContext{x: x, y: y, limit: limit, easySplit: true} if err := ft.aggregateInterval(&ac); err != nil { - return fpResult{}, fpResult{}, err + return splitResult{}, err } if ac.total == 0 { - return fpResult{}, fpResult{}, nil + return splitResult{}, nil } if ac.count0 == 0 || ac.count == 0 { // need to get some items on both sides for the easy split to succeed - return fpResult{}, fpResult{}, errEasySplitFailed + ft.log("easySplit failed: one side missing: count0 %d count %d", ac.count0, ac.count) + return splitResult{}, errEasySplitFailed } // It should not be possible to have ac.lastPrefix0 == nil or ac.lastPrefix == nil @@ -1080,21 +1117,27 @@ func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (fprA, fprB fpResult, err // ac.start / ac.end are only set in aggregateEdge which fails with // errEasySplitFailed if easySplit is enabled, so we can ignore them here - fprA = fpResult{ + middle := make(KeyBytes, ft.keyLen) + ac.lastPrefix0.idAfter(middle) + part0 := fpResult{ fp: ac.fp0, count: ac.count0, itype: ac.itype, start: ft.iter(x), end: ft.endIterFromPrefix(*ac.lastPrefix0), } - fprB = fpResult{ + part1 := fpResult{ fp: ac.fp, count: ac.count, itype: ac.itype, - start: fprA.end.Clone(), + start: part0.end.Clone(), 
end: ft.endIterFromPrefix(*ac.lastPrefix), } - return fprA, fprB, nil + return splitResult{ + part0: part0, + part1: part1, + middle: middle, + }, nil } func (ft *fpTree) dumpNode(w io.Writer, idx nodeIndex, indent, dir string) { diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 69d0c29056..92fd40c14d 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -1212,7 +1212,8 @@ func TestFPTreeManyItems(t *testing.T) { // TBD: test start/end iterators } -func verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes) { +func verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes, depth, maxDepth int) { + t.Logf("depth %d", depth) t.Logf("--- fingerprint interval %s %s ---", x.String(), y.String()) fpr, err := ft.fingerprintInterval(x, y, -1) require.NoError(t, err) @@ -1226,47 +1227,98 @@ func verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes) { m := fpr.count / 2 t.Logf("--- easy split %s %s %d ---", x.String(), y.String(), m) - fpr1, fpr2, err := ft.easySplit(x[:], y[:], int(m)) + sr, err := ft.easySplit(x[:], y[:], int(m)) require.NoError(t, err) - require.NotZero(t, fpr1.count) - require.NotZero(t, fpr2.count) - require.Equal(t, fpr.count, fpr1.count+fpr2.count) - require.Equal(t, fpr.itype, fpr1.itype) - require.Equal(t, fpr.itype, fpr2.itype) - fp := fpr1.fp - fp.update(fpr2.fp[:]) + require.NotZero(t, sr.part0.count) + require.NotZero(t, sr.part1.count) + require.Equal(t, fpr.count, sr.part0.count+sr.part1.count) + require.Equal(t, fpr.itype, sr.part0.itype) + require.Equal(t, fpr.itype, sr.part1.itype) + fp := sr.part0.fp + fp.update(sr.part1.fp[:]) require.Equal(t, fpr.fp, fp) - require.Equal(t, a, itKey(t, fpr1.start)) - require.Equal(t, b, itKey(t, fpr2.end)) - middle := itKey(t, fpr1.end) - require.Equal(t, middle, itKey(t, fpr2.start)) + require.Equal(t, a, itKey(t, sr.part0.start)) + require.Equal(t, b, itKey(t, sr.part1.end)) + precMiddle := itKey(t, sr.part0.end) + require.Equal(t, precMiddle, 
itKey(t, sr.part1.start)) - fpr11, err := ft.fingerprintInterval(x, middle, -1) + fpr11, err := ft.fingerprintInterval(x, precMiddle, -1) require.NoError(t, err) - require.Equal(t, fpr1.fp, fpr11.fp) - require.Equal(t, fpr1.count, fpr11.count) + require.Equal(t, sr.part0.fp, fpr11.fp) + require.Equal(t, sr.part0.count, fpr11.count) require.Equal(t, a, itKey(t, fpr11.start)) - require.Equal(t, middle, itKey(t, fpr11.end)) + require.Equal(t, precMiddle, itKey(t, fpr11.end)) - fpr12, err := ft.fingerprintInterval(middle, y, -1) + fpr12, err := ft.fingerprintInterval(precMiddle, y, -1) require.NoError(t, err) - require.Equal(t, fpr2.fp, fpr12.fp) - require.Equal(t, fpr2.count, fpr12.count) - require.Equal(t, middle, itKey(t, fpr12.start)) + require.Equal(t, sr.part1.fp, fpr12.fp) + require.Equal(t, sr.part1.count, fpr12.count) + require.Equal(t, precMiddle, itKey(t, fpr12.start)) require.Equal(t, b, itKey(t, fpr12.end)) - // TBD: QQQQQ: recurse! + fpr11, err = ft.fingerprintInterval(x, sr.middle, -1) + require.NoError(t, err) + require.Equal(t, sr.part0.fp, fpr11.fp) + require.Equal(t, sr.part0.count, fpr11.count) + require.Equal(t, a, itKey(t, fpr11.start)) + require.Equal(t, precMiddle, itKey(t, fpr11.end)) + + fpr12, err = ft.fingerprintInterval(sr.middle, y, -1) + require.NoError(t, err) + require.Equal(t, sr.part1.fp, fpr12.fp) + require.Equal(t, sr.part1.count, fpr12.count) + require.Equal(t, precMiddle, itKey(t, fpr12.start)) + require.Equal(t, b, itKey(t, fpr12.end)) + + if depth < maxDepth { + verifyEasySplit(t, ft, x, sr.middle, depth+1, maxDepth) + verifyEasySplit(t, ft, sr.middle, y, depth+1, maxDepth) + } } func TestEasySplit(t *testing.T) { var np nodePool - maxDepth := 24 + maxDepth := 5 + // count := 25 ft := newFPTree(&np, newInMemIDStore(32), 32, maxDepth) - for range 10 { - h := types.RandomHash() + // QQQQQ: rm + for _, h := range []string{ + "00754cf490eeed75fa614b77d6b2b3cd16298711c126f73e8d265304dd251a50", + 
"16d5ed7d8c71b7c7d6ba00340355e6a0de63ed48089a5d1e29dac608d96d246d", + "2edbb20246c25fbf3d8b56cd183f3ede530a02b9658babd7a90295ac645e2aa2", + "45450dfbdf6613eb137daa99cffe47a1b7e21454301bdf8f814b26c309bb2e4c", + "4b42fe0661e3356998293436b83d28d253751a39c382bdeb310f13dec9b0e79d", + "5349101f5ad0ed08bae1cfb95dbb0399dfe87017783ae67e48254445391a1e5e", + "5c3a1f51e1f84a93bbbd284f609a61e0b71b641e155b34f447fc1a189ab0bb08", + "68e8c4773cc7f0218503ed3b2da426bdaf288c8722a091da57edb8ff8303fa03", + "6edaa71d8d400dfcbb8dc3fa6041ae602127db1cd84ec17b729492d9f53025ee", + "7f509a73fef1d7e44abf87ceb775b4e4acc312271caca84aa719cf63ca6ae3ee", + "884cba4481c5d81ee62158baa74fc315fcd6a089b50c7abf197f75acfa40d9ef", + "8f9dbdd4d95afc56c0744b1f1cd6e98dc730078438bb6e5511a2bb8f1e7d16d0", + "92ccb9bb426375dbae456fc4f743ffa69543a9741ea9600ded314ce56dd678af", + "942469ecfc783631c644f1af572e18bf45977d3d9b0628a1bcd9cffbf86be3bf", + "97d46b0e99abf5e290b0783e1595fded3393ac7e7badda2f348859ca085afe80", + "a4d9bc80fd93d3008930089148ab55e3d576f9ce14a1f900d0cc12ca1632e478", + "aee28a625a79951613c5e3165e396e3f5a459803cf16c4305447b86f55dd9048", + "af8255c5db1cf9a9ba5fcb76dbaace9e08fd9d8aa11b02c034f43d90322ee3d1", + "bf6dfedbf2152fabb7ebaedab3589c8c21d0f6e996fee2b3f93740908e0e0115", + "da0c349844ca0d393996ada10fb3ce00eb002ed2f3b83e87f246b6863ed3b5d1", + "dea9caf0c26793cb46f70b30e3e772cb210cf21ced570dfe5e4257da37873715", + "e358c3b354a629b78c470ab873609f96dd5f3648d9eee6fe133fe5182514b057", + "e86dc212c57bdb978d0440fcc95b2900ad0f685e8655f958c1ab19c268653712", + "f3f8560d9d8d8342698c728bff154e797dfaaa9acc88cfbca7825db200310c69", + "f593bc74323d291c9c31ce3ad3af4826105f068519480eec7491360b0f01de8e", + } { + h := types.HexToHash32(h) t.Logf("adding hash %s", h.String()) ft.addHash(h[:]) } + // QQQQQ: restore + // for range count { + // h := types.RandomHash() + // t.Logf("adding hash %s", h.String()) + // ft.addHash(h[:]) + // } k, err := ft.start().Key() require.NoError(t, err) x := k.(KeyBytes) @@ -1281,8 +1333,8 
@@ func TestEasySplit(t *testing.T) { ft.dump(&sb) t.Logf("tree:\n%s", sb.String()) - verifyEasySplit(t, ft, x, x) - // TBD: test split with leafs that have c > 1 + verifyEasySplit(t, ft, x, x, 0, maxDepth-2) + // TBD: test split both with few items (incomplete tree) and many items } const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" From 04561073fae61460bf42a9a664bf6bcbe7584cc4 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 1 Aug 2024 16:52:43 +0400 Subject: [PATCH 57/76] wip5 --- sync2/dbsync/dbitemstore.go | 24 +++++++- sync2/dbsync/fptree.go | 2 - sync2/dbsync/fptree_test.go | 112 ++++++++++++++++-------------------- sync2/hashsync/interface.go | 11 +++- 4 files changed, 79 insertions(+), 70 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 83e7c61724..18578bc14e 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -99,11 +99,19 @@ func (d *DBItemStore) SplitRange( preceding hashsync.Iterator, x, y hashsync.Ordered, count int, -) (hashsync.RangeInfo, hashsync.RangeInfo, error) { +) ( + hashsync.RangeInfo, + hashsync.RangeInfo, + error, +) { if err := d.EnsureLoaded(); err != nil { return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err } panic("TBD") + // fpr1, fpr2, err := d.ft.easySplit(x.(KeyBytes), y.(KeyBytes), count) + // if err != nil { + // return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err + // } // fpr1, fpr2, err := d.ft.splitFingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) // if err != nil { // return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err @@ -216,10 +224,20 @@ func (a *ItemStoreAdapter) GetRangeInfo(preceding hashsync.Iterator, x hashsync. 
}, nil } -func (a *ItemStoreAdapter) SplitRange(preceding hashsync.Iterator, x hashsync.Ordered, y hashsync.Ordered, count int) (hashsync.RangeInfo, hashsync.RangeInfo, error) { +func (a *ItemStoreAdapter) SplitRange( + preceding hashsync.Iterator, + x hashsync.Ordered, + y hashsync.Ordered, + count int, +) ( + hashsync.RangeInfo, + hashsync.RangeInfo, + error, +) { hx := x.(types.Hash32) hy := y.(types.Hash32) - info1, info2, err := a.s.SplitRange(preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) + info1, info2, err := a.s.SplitRange( + preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) if err != nil { return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err } diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index cd6143b9e3..866db9a34b 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -408,8 +408,6 @@ func (ac *aggContext) pruneY(node node, p prefix) bool { } func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { - fmt.Fprintf(os.Stderr, "QQQQQ: maybeIncludeNode: limit %d node.c %d easySplit %v leaf %v\n", - ac.limit, node.c, ac.easySplit, node.leaf()) switch { case ac.limit < 0: case uint32(ac.limit) >= node.c: diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 92fd40c14d..0a7e4cbc21 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -1212,9 +1212,17 @@ func TestFPTreeManyItems(t *testing.T) { // TBD: test start/end iterators } -func verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes, depth, maxDepth int) { - t.Logf("depth %d", depth) - t.Logf("--- fingerprint interval %s %s ---", x.String(), y.String()) +func verifyEasySplit( + t *testing.T, + ft *fpTree, + x, y KeyBytes, + depth, + maxDepth int, +) ( + succeeded, failed int, +) { + // t.Logf("depth %d", depth) + // t.Logf("--- fingerprint interval %s %s ---", x.String(), y.String()) fpr, err := ft.fingerprintInterval(x, y, -1) require.NoError(t, err) if fpr.count <= 1 { @@ -1226,8 +1234,12 @@ func 
verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes, depth, maxDepth in require.NoError(t, err) m := fpr.count / 2 - t.Logf("--- easy split %s %s %d ---", x.String(), y.String(), m) + // t.Logf("--- easy split %s %s %d ---", x.String(), y.String(), m) sr, err := ft.easySplit(x[:], y[:], int(m)) + if err != nil { + require.ErrorIs(t, err, errEasySplitFailed) + return 0, 1 + } require.NoError(t, err) require.NotZero(t, sr.part0.count) require.NotZero(t, sr.part1.count) @@ -1270,71 +1282,45 @@ func verifyEasySplit(t *testing.T, ft *fpTree, x, y KeyBytes, depth, maxDepth in require.Equal(t, precMiddle, itKey(t, fpr12.start)) require.Equal(t, b, itKey(t, fpr12.end)) - if depth < maxDepth { - verifyEasySplit(t, ft, x, sr.middle, depth+1, maxDepth) - verifyEasySplit(t, ft, sr.middle, y, depth+1, maxDepth) + if depth >= maxDepth { + return 1, 0 } + s1, f1 := verifyEasySplit(t, ft, x, sr.middle, depth+1, maxDepth) + s2, f2 := verifyEasySplit(t, ft, sr.middle, y, depth+1, maxDepth) + return s1 + s2 + 1, f1 + f2 } func TestEasySplit(t *testing.T) { - var np nodePool - maxDepth := 5 - // count := 25 - ft := newFPTree(&np, newInMemIDStore(32), 32, maxDepth) - // QQQQQ: rm - for _, h := range []string{ - "00754cf490eeed75fa614b77d6b2b3cd16298711c126f73e8d265304dd251a50", - "16d5ed7d8c71b7c7d6ba00340355e6a0de63ed48089a5d1e29dac608d96d246d", - "2edbb20246c25fbf3d8b56cd183f3ede530a02b9658babd7a90295ac645e2aa2", - "45450dfbdf6613eb137daa99cffe47a1b7e21454301bdf8f814b26c309bb2e4c", - "4b42fe0661e3356998293436b83d28d253751a39c382bdeb310f13dec9b0e79d", - "5349101f5ad0ed08bae1cfb95dbb0399dfe87017783ae67e48254445391a1e5e", - "5c3a1f51e1f84a93bbbd284f609a61e0b71b641e155b34f447fc1a189ab0bb08", - "68e8c4773cc7f0218503ed3b2da426bdaf288c8722a091da57edb8ff8303fa03", - "6edaa71d8d400dfcbb8dc3fa6041ae602127db1cd84ec17b729492d9f53025ee", - "7f509a73fef1d7e44abf87ceb775b4e4acc312271caca84aa719cf63ca6ae3ee", - "884cba4481c5d81ee62158baa74fc315fcd6a089b50c7abf197f75acfa40d9ef", - 
"8f9dbdd4d95afc56c0744b1f1cd6e98dc730078438bb6e5511a2bb8f1e7d16d0", - "92ccb9bb426375dbae456fc4f743ffa69543a9741ea9600ded314ce56dd678af", - "942469ecfc783631c644f1af572e18bf45977d3d9b0628a1bcd9cffbf86be3bf", - "97d46b0e99abf5e290b0783e1595fded3393ac7e7badda2f348859ca085afe80", - "a4d9bc80fd93d3008930089148ab55e3d576f9ce14a1f900d0cc12ca1632e478", - "aee28a625a79951613c5e3165e396e3f5a459803cf16c4305447b86f55dd9048", - "af8255c5db1cf9a9ba5fcb76dbaace9e08fd9d8aa11b02c034f43d90322ee3d1", - "bf6dfedbf2152fabb7ebaedab3589c8c21d0f6e996fee2b3f93740908e0e0115", - "da0c349844ca0d393996ada10fb3ce00eb002ed2f3b83e87f246b6863ed3b5d1", - "dea9caf0c26793cb46f70b30e3e772cb210cf21ced570dfe5e4257da37873715", - "e358c3b354a629b78c470ab873609f96dd5f3648d9eee6fe133fe5182514b057", - "e86dc212c57bdb978d0440fcc95b2900ad0f685e8655f958c1ab19c268653712", - "f3f8560d9d8d8342698c728bff154e797dfaaa9acc88cfbca7825db200310c69", - "f593bc74323d291c9c31ce3ad3af4826105f068519480eec7491360b0f01de8e", - } { - h := types.HexToHash32(h) - t.Logf("adding hash %s", h.String()) - ft.addHash(h[:]) - } - // QQQQQ: restore - // for range count { - // h := types.RandomHash() - // t.Logf("adding hash %s", h.String()) - // ft.addHash(h[:]) - // } - k, err := ft.start().Key() - require.NoError(t, err) - x := k.(KeyBytes) - v := load64(x) & ^(1<<(64-maxDepth) - 1) - binary.BigEndian.PutUint64(x, v) - for i := 8; i < len(x); i++ { - x[i] = 0 - } + maxDepth := 17 + count := 10000 + for range 5 { + var np nodePool + ft := newFPTree(&np, newInMemIDStore(32), 32, maxDepth) + for range count { + h := types.RandomHash() + // t.Logf("adding hash %s", h.String()) + ft.addHash(h[:]) + } + k, err := ft.start().Key() + require.NoError(t, err) + x := k.(KeyBytes) + v := load64(x) & ^(1<<(64-maxDepth) - 1) + binary.BigEndian.PutUint64(x, v) + for i := 8; i < len(x); i++ { + x[i] = 0 + } - ft.traceEnabled = true - var sb strings.Builder - ft.dump(&sb) - t.Logf("tree:\n%s", sb.String()) + // ft.traceEnabled = true + // var sb 
strings.Builder + // ft.dump(&sb) + // t.Logf("tree:\n%s", sb.String()) - verifyEasySplit(t, ft, x, x, 0, maxDepth-2) - // TBD: test split both with few items (incomplete tree) and many items + succeeded, failed := verifyEasySplit(t, ft, x, x, 0, maxDepth-2) + successRate := float64(succeeded) * 100 / float64(succeeded+failed) + t.Logf("succeeded %d, failed %d, success rate %.2f%%", + succeeded, failed, successRate) + require.GreaterOrEqual(t, successRate, 95.0) + } } const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 3379b75976..75e70a28d4 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -24,10 +24,17 @@ type Iterator interface { Clone() Iterator } +// RangeInfo contains information about a range of items in the ItemStore as returned by +// ItemStore.GetRangeInfo. type RangeInfo struct { + // Fingerprint of the interval Fingerprint any - Count int - Start, End Iterator + // Number of items in the interval + Count int + // An iterator pointing to the beginning of the interval or nil if count is zero. + Start Iterator + // An iterator pointing to the end of the interval or nil if count is zero. 
+ End Iterator } // ItemStore represents the data store that can be synced against a remote peer From f93a02c5741543f38786b557a3f7d823753a2e0e Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 13 Aug 2024 12:42:04 +0400 Subject: [PATCH 58/76] wip6 --- sync2/dbsync/dbitemstore.go | 116 +++++++++++++++++++----------- sync2/hashsync/interface.go | 10 ++- sync2/hashsync/mocks_test.go | 55 +++++++++++--- sync2/hashsync/rangesync.go | 68 +++++++++--------- sync2/hashsync/rangesync_test.go | 38 ++++++---- sync2/hashsync/sync_tree_store.go | 15 ++-- 6 files changed, 198 insertions(+), 104 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 18578bc14e..145c58a599 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -2,6 +2,7 @@ package dbsync import ( "context" + "errors" "sync" "github.com/spacemeshos/go-spacemesh/common/types" @@ -100,33 +101,61 @@ func (d *DBItemStore) SplitRange( x, y hashsync.Ordered, count int, ) ( - hashsync.RangeInfo, - hashsync.RangeInfo, + hashsync.SplitInfo, error, ) { + if count <= 0 { + panic("BUG: bad split count") + } + if err := d.EnsureLoaded(); err != nil { - return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err + return hashsync.SplitInfo{}, err + } + + sr, err := d.ft.easySplit(x.(KeyBytes), y.(KeyBytes), count) + if err == nil { + return hashsync.SplitInfo{ + Parts: [2]hashsync.RangeInfo{ + { + Fingerprint: sr.part0.fp, + Count: int(sr.part0.count), + Start: sr.part0.start, + End: sr.part0.end, + }, + { + Fingerprint: sr.part1.fp, + Count: int(sr.part1.count), + Start: sr.part1.start, + End: sr.part1.end, + }, + }, + Middle: sr.middle, + }, nil + } + + if !errors.Is(err, errEasySplitFailed) { + return hashsync.SplitInfo{}, err + } + + part0, err := d.GetRangeInfo(preceding, x, y, count) + if err != nil { + return hashsync.SplitInfo{}, err } - panic("TBD") - // fpr1, fpr2, err := d.ft.easySplit(x.(KeyBytes), y.(KeyBytes), count) - // if err != nil { - // return 
hashsync.RangeInfo{}, hashsync.RangeInfo{}, err - // } - // fpr1, fpr2, err := d.ft.splitFingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) - // if err != nil { - // return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err - // } - // return hashsync.RangeInfo{ - // Fingerprint: fpr1.fp, - // Count: int(fpr1.count), - // Start: fpr1.start, - // End: fpr1.end, - // }, hashsync.RangeInfo{ - // Fingerprint: fpr2.fp, - // Count: int(fpr2.count), - // Start: fpr2.start, - // End: fpr2.end, - // }, nil + if part0.Count == 0 { + return hashsync.SplitInfo{}, errors.New("can't split empty range") + } + middle, err := part0.End.Key() + if err != nil { + return hashsync.SplitInfo{}, err + } + part1, err := d.GetRangeInfo(part0.End.Clone(), middle, y, -1) + if err != nil { + return hashsync.SplitInfo{}, err + } + return hashsync.SplitInfo{ + Parts: [2]hashsync.RangeInfo{part0, part1}, + Middle: middle, + }, nil } // Min implements hashsync.ItemStore. @@ -230,33 +259,40 @@ func (a *ItemStoreAdapter) SplitRange( y hashsync.Ordered, count int, ) ( - hashsync.RangeInfo, - hashsync.RangeInfo, + hashsync.SplitInfo, error, ) { hx := x.(types.Hash32) hy := y.(types.Hash32) - info1, info2, err := a.s.SplitRange( + si, err := a.s.SplitRange( preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) if err != nil { - return hashsync.RangeInfo{}, hashsync.RangeInfo{}, err + return hashsync.SplitInfo{}, err } var fp1, fp2 types.Hash12 - src1 := info1.Fingerprint.(fingerprint) - src2 := info2.Fingerprint.(fingerprint) + src1 := si.Parts[0].Fingerprint.(fingerprint) + src2 := si.Parts[1].Fingerprint.(fingerprint) copy(fp1[:], src1[:]) copy(fp2[:], src2[:]) - return hashsync.RangeInfo{ - Fingerprint: fp1, - Count: info1.Count, - Start: a.wrapIterator(info1.Start), - End: a.wrapIterator(info1.End), - }, hashsync.RangeInfo{ - Fingerprint: fp2, - Count: info2.Count, - Start: a.wrapIterator(info2.Start), - End: a.wrapIterator(info2.End), - }, nil + var middle types.Hash32 + copy(middle[:], 
si.Middle.(KeyBytes)) + return hashsync.SplitInfo{ + Parts: [2]hashsync.RangeInfo{ + { + Fingerprint: fp1, + Count: si.Parts[0].Count, + Start: a.wrapIterator(si.Parts[0].Start), + End: a.wrapIterator(si.Parts[0].End), + }, + { + Fingerprint: fp2, + Count: si.Parts[1].Count, + Start: a.wrapIterator(si.Parts[1].Start), + End: a.wrapIterator(si.Parts[1].End), + }, + }, + Middle: middle, + }, nil } // Has implements hashsync.ItemStore. diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 75e70a28d4..86cf0a9243 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -37,6 +37,14 @@ type RangeInfo struct { End Iterator } +// SplitInfo contains information about range split in two. +type SplitInfo struct { + // 2 parts of the range + Parts [2]RangeInfo + // Middle point between the ranges + Middle Ordered +} + // ItemStore represents the data store that can be synced against a remote peer type ItemStore interface { // Add adds a key to the store @@ -49,7 +57,7 @@ type ItemStore interface { GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) // SplitRange splits the range roughly after the specified count of items, // returning RangeInfo for the first half and the second half of the range. - SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) + SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) // Min returns the iterator pointing at the minimum element // in the store. If the store is empty, it returns nil Min() (Iterator, error) diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go index 93adca1ec9..30ecc97651 100644 --- a/sync2/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -43,6 +43,44 @@ func (m *MockIterator) EXPECT() *MockIteratorMockRecorder { return m.recorder } +// Clone mocks base method. 
+func (m *MockIterator) Clone() Iterator { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clone") + ret0, _ := ret[0].(Iterator) + return ret0 +} + +// Clone indicates an expected call of Clone. +func (mr *MockIteratorMockRecorder) Clone() *MockIteratorCloneCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockIterator)(nil).Clone)) + return &MockIteratorCloneCall{Call: call} +} + +// MockIteratorCloneCall wrap *gomock.Call +type MockIteratorCloneCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockIteratorCloneCall) Return(arg0 Iterator) *MockIteratorCloneCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockIteratorCloneCall) Do(f func() Iterator) *MockIteratorCloneCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockIteratorCloneCall) DoAndReturn(f func() Iterator) *MockIteratorCloneCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Key mocks base method. func (m *MockIterator) Key() (Ordered, error) { m.ctrl.T.Helper() @@ -337,13 +375,12 @@ func (c *MockItemStoreMinCall) DoAndReturn(f func() (Iterator, error)) *MockItem } // SplitRange mocks base method. -func (m *MockItemStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { +func (m *MockItemStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SplitRange", preceding, x, y, count) - ret0, _ := ret[0].(RangeInfo) - ret1, _ := ret[1].(RangeInfo) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(SplitInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 } // SplitRange indicates an expected call of SplitRange. 
@@ -359,19 +396,19 @@ type MockItemStoreSplitRangeCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockItemStoreSplitRangeCall) Return(arg0, arg1 RangeInfo, arg2 error) *MockItemStoreSplitRangeCall { - c.Call = c.Call.Return(arg0, arg1, arg2) +func (c *MockItemStoreSplitRangeCall) Return(arg0 SplitInfo, arg1 error) *MockItemStoreSplitRangeCall { + c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreSplitRangeCall) Do(f func(Iterator, Ordered, Ordered, int) (RangeInfo, RangeInfo, error)) *MockItemStoreSplitRangeCall { +func (c *MockItemStoreSplitRangeCall) Do(f func(Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreSplitRangeCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) (RangeInfo, RangeInfo, error)) *MockItemStoreSplitRangeCall { +func (c *MockItemStoreSplitRangeCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 53fd36d184..f6f66c22f6 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -250,7 +250,7 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info Ran return nil } -func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) (it Iterator, done bool, err error) { +func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) (done bool, err error) { rsr.log.Debug("handleMessage", IteratorField("preceding", preceding), zap.String("msg", SyncMessageToString(msg))) x := msg.X() @@ -261,36 +261,36 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // even send X & Y (SendEmptySet) it, err := rsr.is.Min() if err != nil { - return nil, 
false, err + return false, err } if it == nil { // We don't have any items at all, too if msg.Type() == MessageTypeProbe { info, err := rsr.is.GetRangeInfo(preceding, nil, nil, -1) if err != nil { - return nil, false, err + return false, err } rsr.log.Debug("handleMessage: send probe response", HexField("fingerpint", info.Fingerprint), zap.Int("count", info.Count), IteratorField("it", it)) if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, 0, it); err != nil { - return nil, false, err + return false, err } } - return nil, true, nil + return true, nil } x, err = it.Key() if err != nil { - return nil, false, err + return false, err } y = x } else if x == nil || y == nil { - return nil, false, errors.New("bad X or Y") + return false, errors.New("bad X or Y") } info, err := rsr.is.GetRangeInfo(preceding, x, y, -1) if err != nil { - return nil, false, err + return false, err } rsr.log.Debug("handleMessage: range info", HexField("x", x), HexField("y", y), @@ -314,7 +314,7 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg rsr.log.Debug("handleMessage: send items", zap.Int("count", info.Count), IteratorField("start", info.Start)) if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { - return nil, false, err + return false, err } } else { rsr.log.Debug("handleMessage: local range is empty") @@ -322,7 +322,7 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg case msg.Type() == MessageTypeProbe: sampleSize := msg.Count() if sampleSize > maxSampleSize { - return nil, false, fmt.Errorf("bad minhash sample size %d (max %d)", + return false, fmt.Errorf("bad minhash sample size %d (max %d)", msg.Count(), maxSampleSize) } else if sampleSize > info.Count { sampleSize = info.Count @@ -336,11 +336,11 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg // msg.Fingerprint(), info.Fingerprint) } if err := c.SendProbeResponse(x, y, info.Fingerprint, 
info.Count, sampleSize, it); err != nil { - return nil, false, err + return false, err } - return nil, true, nil + return true, nil case msg.Type() != MessageTypeFingerprint: - return nil, false, fmt.Errorf("unexpected message type %s", msg.Type()) + return false, fmt.Errorf("unexpected message type %s", msg.Type()) case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): // The range is synced // case (info.Count+1)/2 <= rsr.maxSendRange: @@ -354,17 +354,17 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg HexField("x", x), HexField("y", y), zap.Int("count", info.Count)) // fmt.Fprintf(os.Stderr, "small incoming range: %s -> SendItems\n", msg) if err := c.SendRangeContents(x, y, info.Count); err != nil { - return nil, false, err + return false, err } if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { - return nil, false, err + return false, err } } else { rsr.log.Debug("handleMessage: empty incoming range", HexField("x", x), HexField("y", y)) // fmt.Fprintf(os.Stderr, "small incoming range: %s -> empty range msg\n", msg) if err := c.SendEmptyRange(x, y); err != nil { - return nil, false, err + return false, err } } default: @@ -377,37 +377,37 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg rsr.log.Debug("handleMessage: PRE split range", HexField("x", x), HexField("y", y), zap.Int("countArg", count)) - part0, part1, err := rsr.is.SplitRange(preceding, x, y, count) + si, err := rsr.is.SplitRange(preceding, x, y, count) if err != nil { - return nil, false, err + return false, err } rsr.log.Debug("handleMessage: split range", HexField("x", x), HexField("y", y), zap.Int("countArg", count), - zap.Int("count0", part0.Count), - HexField("fp0", part0.Fingerprint), - IteratorField("start0", part0.Start), - IteratorField("end0", part0.End), - zap.Int("count1", part1.Count), - HexField("fp1", part1.Fingerprint), - IteratorField("start1", part1.End), - IteratorField("end1", 
part1.End)) - middle, err := part0.End.Key() + zap.Int("count0", si.Parts[0].Count), + HexField("fp0", si.Parts[0].Fingerprint), + IteratorField("start0", si.Parts[0].Start), + IteratorField("end0", si.Parts[0].End), + zap.Int("count1", si.Parts[1].Count), + HexField("fp1", si.Parts[1].Fingerprint), + IteratorField("start1", si.Parts[1].End), + IteratorField("end1", si.Parts[1].End)) + middle, err := si.Parts[0].End.Key() if err != nil { - return nil, false, err + return false, err } - if err := rsr.processSubrange(c, x, middle, part0); err != nil { - return nil, false, err + if err := rsr.processSubrange(c, x, middle, si.Parts[0]); err != nil { + return false, err } // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) - if err := rsr.processSubrange(c, middle, y, part1); err != nil { - return nil, false, err + if err := rsr.processSubrange(c, middle, y, si.Parts[1]); err != nil { + return false, err } // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", // msg.X(), middle, msg.Y(), msg) done = false } - return info.End, done, nil + return done, nil } func (rsr *RangeSetReconciler) Initiate(c Conduit) error { @@ -634,7 +634,7 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo // breaks if we capture the iterator from handleMessage and // pass it to the next handleMessage call (it shouldn't) var msgDone bool - _, msgDone, err = rsr.handleMessage(c, nil, msg) + msgDone, err = rsr.handleMessage(c, nil, msg) if !msgDone { done = false } diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index 41723a7768..db8cabf6f1 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -315,26 +315,29 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) ( return r, nil } -func (ds *dumbStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { +func (ds *dumbStore) SplitRange(preceding Iterator, x, 
y Ordered, count int) (SplitInfo, error) { if count <= 0 { panic("BUG: bad split count") } part0, err := ds.GetRangeInfo(preceding, x, y, count) if err != nil { - return RangeInfo{}, RangeInfo{}, err + return SplitInfo{}, err } if part0.Count == 0 { - return RangeInfo{}, RangeInfo{}, errors.New("can't split empty range") + return SplitInfo{}, errors.New("can't split empty range") } middle, err := part0.End.Key() if err != nil { - return RangeInfo{}, RangeInfo{}, err + return SplitInfo{}, err } part1, err := ds.GetRangeInfo(part0.End.Clone(), middle, y, -1) if err != nil { - return RangeInfo{}, RangeInfo{}, err + return SplitInfo{}, err } - return part0, part1, nil + return SplitInfo{ + Parts: [2]RangeInfo{part0, part1}, + Middle: middle, + }, nil } func (ds *dumbStore) Min() (Iterator, error) { @@ -510,24 +513,31 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in return vs.verifySameRangeInfo(ri1, ri2), nil } -func (vs *verifiedStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { +func (vs *verifiedStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { var ( - ri11, ri12, ri21, ri22 RangeInfo - err error + si1, si2 SplitInfo + err error ) if preceding != nil { p := preceding.(verifiedStoreIterator) - ri11, ri12, err = vs.knownGood.SplitRange(p.knownGood, x, y, count) + si1, err = vs.knownGood.SplitRange(p.knownGood, x, y, count) require.NoError(vs.t, err) - ri21, ri22, err = vs.store.SplitRange(p.it, x, y, count) + si2, err = vs.store.SplitRange(p.it, x, y, count) require.NoError(vs.t, err) } else { - ri11, ri12, err = vs.knownGood.SplitRange(nil, x, y, count) + si1, err = vs.knownGood.SplitRange(nil, x, y, count) require.NoError(vs.t, err) - ri21, ri22, err = vs.store.SplitRange(nil, x, y, count) + si2, err = vs.store.SplitRange(nil, x, y, count) require.NoError(vs.t, err) } - return vs.verifySameRangeInfo(ri11, ri21), vs.verifySameRangeInfo(ri12, ri22), nil + 
require.Equal(vs.t, si1.Middle, si2.Middle, "split middle") + return SplitInfo{ + Parts: [2]RangeInfo{ + vs.verifySameRangeInfo(si1.Parts[0], si2.Parts[0]), + vs.verifySameRangeInfo(si1.Parts[1], si2.Parts[1]), + }, + Middle: si1.Middle, + }, nil } func (vs *verifiedStore) Min() (Iterator, error) { diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index c6cf658bb9..9084c38f23 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -115,26 +115,29 @@ func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count i } // SplitRange implements ItemStore. -func (sts *SyncTreeStore) SplitRange(preceding Iterator, x, y Ordered, count int) (RangeInfo, RangeInfo, error) { +func (sts *SyncTreeStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { if count <= 0 { panic("BUG: bad split count") } part0, err := sts.GetRangeInfo(preceding, x, y, count) if err != nil { - return RangeInfo{}, RangeInfo{}, err + return SplitInfo{}, err } if part0.Count == 0 { - return RangeInfo{}, RangeInfo{}, errors.New("can't split empty range") + return SplitInfo{}, errors.New("can't split empty range") } middle, err := part0.End.Key() if err != nil { - return RangeInfo{}, RangeInfo{}, err + return SplitInfo{}, err } part1, err := sts.GetRangeInfo(part0.End.Clone(), middle, y, -1) if err != nil { - return RangeInfo{}, RangeInfo{}, err + return SplitInfo{}, err } - return part0, part1, nil + return SplitInfo{ + Parts: [2]RangeInfo{part0, part1}, + Middle: middle, + }, nil } // Min implements ItemStore. 
From 72018dc663037ac5b3b9451ba654656b0c988464 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sun, 18 Aug 2024 03:50:22 +0400 Subject: [PATCH 59/76] wip7 --- sync2/dbsync/dbitemstore.go | 2 + sync2/dbsync/dbiter.go | 91 ++++++++++++++++++++++++++++-------- sync2/dbsync/dbiter_test.go | 35 +++++++++----- sync2/dbsync/p2p_test.go | 36 ++++++++++++-- sync2/dbsync/sqlidstore.go | 10 +++- sync2/hashsync/rangesync.go | 10 ++-- sync2/hashsync/split_sync.go | 1 + 7 files changed, 140 insertions(+), 45 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 145c58a599..1fad5abcb5 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -114,6 +114,7 @@ func (d *DBItemStore) SplitRange( sr, err := d.ft.easySplit(x.(KeyBytes), y.(KeyBytes), count) if err == nil { + // fmt.Fprintf(os.Stderr, "QQQQQ: fast split, middle: %s\n", sr.middle.String()) return hashsync.SplitInfo{ Parts: [2]hashsync.RangeInfo{ { @@ -137,6 +138,7 @@ func (d *DBItemStore) SplitRange( return hashsync.SplitInfo{}, err } + // fmt.Fprintf(os.Stderr, "QQQQQ: slow split x %s y %s\n", x.(fmt.Stringer), y.(fmt.Stringer)) part0, err := d.GetRangeInfo(preceding, x, y, count) if err != nil { return hashsync.SplitInfo{}, err diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index 0d80ab696d..e871295a79 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -6,6 +6,7 @@ import ( "errors" "slices" + "github.com/hashicorp/golang-lru/v2/simplelru" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) @@ -54,6 +55,23 @@ func (k KeyBytes) isZero() bool { var errEmptySet = errors.New("empty range") +type dbIDKey struct { + id string + chunkSize int +} + +type lru = simplelru.LRU[dbIDKey, []KeyBytes] + +const lruCacheSize = 1024 * 1024 + +func newLRU() *lru { + cache, err := simplelru.NewLRU[dbIDKey, []KeyBytes](lruCacheSize, nil) + if err != nil { + panic("BUG: failed to create LRU cache: " + 
err.Error()) + } + return cache +} + type dbRangeIterator struct { db sql.Database from KeyBytes @@ -65,6 +83,7 @@ type dbRangeIterator struct { keyLen int singleChunk bool loaded bool + cache *lru } var _ hashsync.Iterator = &dbRangeIterator{} @@ -76,6 +95,7 @@ func newDBRangeIterator( query string, from KeyBytes, maxChunkSize int, + lru *lru, ) hashsync.Iterator { if from == nil { panic("BUG: makeDBIterator: nil from") @@ -95,9 +115,28 @@ func newDBRangeIterator( chunk: make([]KeyBytes, maxChunkSize), singleChunk: false, loaded: false, + cache: lru, } } +func (it *dbRangeIterator) loadCached(key dbIDKey) (bool, int) { + chunk, ok := it.cache.Get(key) + if !ok { + // fmt.Fprintf(os.Stderr, "QQQQQ: cache miss\n") + return false, 0 + } + + // fmt.Fprintf(os.Stderr, "QQQQQ: cache hit, chunk size %d\n", len(chunk)) + for n, id := range it.chunk[:len(chunk)] { + if id == nil { + id = make([]byte, it.keyLen) + it.chunk[n] = id + } + copy(id, chunk[n]) + } + return true, len(chunk) +} + func (it *dbRangeIterator) load() error { it.pos = 0 if it.singleChunk { @@ -114,27 +153,39 @@ func (it *dbRangeIterator) load() error { } else { it.chunk = it.chunk[:it.chunkSize] } - var ierr error - _, err := it.db.Exec( - it.query, func(stmt *sql.Statement) { - stmt.BindBytes(1, it.from) - stmt.BindInt64(2, int64(it.chunkSize)) - }, - func(stmt *sql.Statement) bool { - if n >= len(it.chunk) { - ierr = errors.New("too many rows") - return false - } - // we reuse existing slices when possible for retrieving new IDs - id := it.chunk[n] - if id == nil { - id = make([]byte, it.keyLen) - it.chunk[n] = id + // fmt.Fprintf(os.Stderr, "QQQQQ: from: %s chunkSize: %d\n", hex.EncodeToString(it.from), it.chunkSize) + key := dbIDKey{string(it.from), it.chunkSize} + var ierr, err error + found, n := it.loadCached(key) + if !found { + _, err = it.db.Exec( + it.query, func(stmt *sql.Statement) { + stmt.BindBytes(1, it.from) + stmt.BindInt64(2, int64(it.chunkSize)) + }, + func(stmt *sql.Statement) bool 
{ + if n >= len(it.chunk) { + ierr = errors.New("too many rows") + return false + } + // we reuse existing slices when possible for retrieving new IDs + id := it.chunk[n] + if id == nil { + id = make([]byte, it.keyLen) + it.chunk[n] = id + } + stmt.ColumnBytes(0, id) + n++ + return true + }) + if err == nil && ierr == nil { + cached := make([]KeyBytes, n) + for n, id := range it.chunk[:n] { + cached[n] = slices.Clone(id) } - stmt.ColumnBytes(0, id) - n++ - return true - }) + it.cache.Add(key, cached) + } + } fromZero := it.from.isZero() it.chunkSize = min(it.chunkSize*2, it.maxChunkSize) switch { diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index bcc3529987..16107adf71 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -1,6 +1,7 @@ package dbsync import ( + "context" "encoding/hex" "errors" "fmt" @@ -41,24 +42,35 @@ func TestIncID(t *testing.T) { } func createDB(t *testing.T, keyLen int) sql.Database { - db := sql.InMemory(sql.WithIgnoreSchemaDrift()) + // QQQQQ: FIXME + tmpDir := t.TempDir() + t.Logf("QQQQQ: temp dir: %s", tmpDir) + db, err := sql.Open(fmt.Sprintf("file:%s/test.db", tmpDir), sql.WithIgnoreSchemaDrift()) + require.NoError(t, err) + // db := sql.InMemory(sql.WithIgnoreSchemaDrift()) t.Cleanup(func() { require.NoError(t, db.Close()) }) - _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) + _, err = db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) require.NoError(t, err) return db } func insertDBItems(t *testing.T, db sql.Database, content []KeyBytes) { - for _, id := range content { - _, err := db.Exec( - "insert into foo(id) values(?)", - func(stmt *sql.Statement) { - stmt.BindBytes(1, id) - }, nil) - require.NoError(t, err) - } + err := db.WithTx(context.Background(), func(tx sql.Transaction) error { + for _, id := range content { + _, err := tx.Exec( + "insert into foo(id) values(?)", + func(stmt 
*sql.Statement) { + stmt.BindBytes(1, id) + }, nil) + if err != nil { + return err + } + } + return nil + }) + require.NoError(t, err) } func deleteDBItems(t *testing.T, db sql.Database) { @@ -287,8 +299,9 @@ func TestDBRangeIterator(t *testing.T) { } { deleteDBItems(t, db) insertDBItems(t, db, tc.items) + cache := newLRU() for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { - it := newDBRangeIterator(db, testQuery, tc.from, maxChunkSize) + it := newDBRangeIterator(db, testQuery, tc.from, maxChunkSize, cache) if tc.expErr != nil { _, err := it.Key() require.ErrorIs(t, err, tc.expErr) diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index 5d32a03554..869e8c1c36 100644 --- a/sync2/dbsync/p2p_test.go +++ b/sync2/dbsync/p2p_test.go @@ -2,6 +2,7 @@ package dbsync import ( "context" + "encoding/binary" "errors" "io" "slices" @@ -21,14 +22,20 @@ import ( ) func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { + const maxDepth = 24 log := zaptest.NewLogger(t) + t.Logf("QQQQQ: 0") dbA := populateDB(t, 32, itemsA) + t.Logf("QQQQQ: 1") dbB := populateDB(t, 32, itemsB) mesh, err := mocknet.FullMeshConnected(2) require.NoError(t, err) proto := "itest" - storeA := NewItemStoreAdapter(NewDBItemStore(dbA, "select id from foo", testQuery, 32, 24)) - storeB := NewItemStoreAdapter(NewDBItemStore(dbB, "select id from foo", testQuery, 32, 24)) + t.Logf("QQQQQ: 2") + storeA := NewItemStoreAdapter(NewDBItemStore(dbA, "select id from foo", testQuery, 32, maxDepth)) + t.Logf("QQQQQ: 3") + storeB := NewItemStoreAdapter(NewDBItemStore(dbB, "select id from foo", testQuery, 32, maxDepth)) + t.Logf("QQQQQ: 4") // QQQQQ: rmme // storeB.s.ft.traceEnabled = true @@ -90,7 +97,26 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { // uncomment to enable verbose logging which may slow down tests // hashsync.WithRangeSyncLogger(log.Named("sideB")), }) - require.NoError(t, pss.SyncStore(ctx, srvPeerID, storeB, nil, nil)) + + var x 
*types.Hash32 + it, err := storeB.Min() + require.NoError(t, err) + if it != nil { + x = &types.Hash32{} + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) + v := load64(h[:]) & ^uint64(1<<(64-maxDepth)-1) + binary.BigEndian.PutUint64(x[:], v) + for i := 8; i < len(x); i++ { + x[i] = 0 + } + t.Logf("x: %s", x.String()) + } + + tStart := time.Now() + require.NoError(t, pss.SyncStore(ctx, srvPeerID, storeB, x, x)) + t.Logf("synced in %v", time.Since(tStart)) // // QQQQQ: rmme // sb = strings.Builder{} @@ -103,7 +129,7 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { if len(combinedItems) == 0 { return } - it, err := storeA.Min() + it, err = storeA.Min() require.NoError(t, err) var actItemsA []KeyBytes if len(combinedItems) == 0 { @@ -213,7 +239,7 @@ func TestP2P(t *testing.T) { }) t.Run("random test", func(t *testing.T) { // TODO: increase these values and profile - const nShared = 8000 + const nShared = 800000 const nUniqueA = 400 const nUniqueB = 800 // const nShared = 2 diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 616ba81ecf..b6790e57fd 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -13,12 +13,18 @@ type sqlIDStore struct { db sql.Database query string keyLen int + cache *lru } var _ idStore = &sqlIDStore{} func newSQLIDStore(db sql.Database, query string, keyLen int) *sqlIDStore { - return &sqlIDStore{db: db, query: query, keyLen: keyLen} + return &sqlIDStore{ + db: db, + query: query, + keyLen: keyLen, + cache: newLRU(), + } } func (s *sqlIDStore) clone() idStore { @@ -39,7 +45,7 @@ func (s *sqlIDStore) iter(from KeyBytes) hashsync.Iterator { if len(from) != s.keyLen { panic("BUG: invalid key length") } - return newDBRangeIterator(s.db, s.query, from, sqlMaxChunkSize) + return newDBRangeIterator(s.db, s.query, from, sqlMaxChunkSize, s.cache) } type dbBackedStore struct { diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 
f6f66c22f6..9def50e0a6 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -392,15 +392,11 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg HexField("fp1", si.Parts[1].Fingerprint), IteratorField("start1", si.Parts[1].End), IteratorField("end1", si.Parts[1].End)) - middle, err := si.Parts[0].End.Key() - if err != nil { - return false, err - } - if err := rsr.processSubrange(c, x, middle, si.Parts[0]); err != nil { + if err := rsr.processSubrange(c, x, si.Middle, si.Parts[0]); err != nil { return false, err } // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) - if err := rsr.processSubrange(c, middle, y, si.Parts[1]); err != nil { + if err := rsr.processSubrange(c, si.Middle, y, si.Parts[1]); err != nil { return false, err } // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", @@ -426,7 +422,7 @@ func (rsr *RangeSetReconciler) Initiate(c Conduit) error { } func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { - rsr.log.Debug("inititate", HexField("x", x), HexField("y", y)) + rsr.log.Debug("initiate", HexField("x", x), HexField("y", y)) if x == nil { rsr.log.Debug("initiate: send empty set") if err := c.SendEmptySet(); err != nil { diff --git a/sync2/hashsync/split_sync.go b/sync2/hashsync/split_sync.go index b26187526d..d164b72b12 100644 --- a/sync2/hashsync/split_sync.go +++ b/sync2/hashsync/split_sync.go @@ -205,6 +205,7 @@ func getDelimiters(numPeers int) (h []types.Hash32) { if numPeers < 2 { return nil } + // QQQQQ: TBD: support maxDepth inc := (uint64(0x80) << 56) / uint64(numPeers) h = make([]types.Hash32, numPeers-1) for i, v := 0, uint64(0); i < numPeers-1; i++ { From 251cd3e8e126c96b205246ee98e295caa7627639 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 19 Aug 2024 23:32:53 +0400 Subject: [PATCH 60/76] wip8 -- fast enough --- sync2/dbsync/dbitemstore.go | 65 +++++++++------ sync2/dbsync/dbitemstore_test.go | 32 +++---- 
sync2/dbsync/dbiter.go | 6 +- sync2/dbsync/dbiter_test.go | 16 ++-- sync2/dbsync/fptree.go | 32 +++---- sync2/dbsync/fptree_test.go | 45 +++++----- sync2/dbsync/inmemidstore.go | 6 +- sync2/dbsync/inmemidstore_test.go | 11 ++- sync2/dbsync/p2p_test.go | 129 ++++++++++++++++++----------- sync2/dbsync/sqlidstore.go | 41 ++++++--- sync2/dbsync/sqlidstore_test.go | 83 +++++++++++-------- sync2/hashsync/handler.go | 10 ++- sync2/hashsync/handler_test.go | 8 +- sync2/hashsync/interface.go | 10 +-- sync2/hashsync/log.go | 38 ++++++--- sync2/hashsync/mocks_test.go | 60 +++++++------- sync2/hashsync/multipeer.go | 2 +- sync2/hashsync/multipeer_test.go | 14 ++-- sync2/hashsync/rangesync.go | 41 +++++---- sync2/hashsync/rangesync_test.go | 74 +++++++++++------ sync2/hashsync/setsyncbase.go | 8 +- sync2/hashsync/setsyncbase_test.go | 20 ++--- sync2/hashsync/sync_tree_store.go | 24 ++++-- 23 files changed, 461 insertions(+), 314 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 1fad5abcb5..dc86035af8 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -40,8 +40,9 @@ func NewDBItemStore( } } -func (d *DBItemStore) load() error { - _, err := d.db.Exec(d.loadQuery, nil, +func (d *DBItemStore) load(ctx context.Context) error { + db := ContextSQLExec(ctx, d.db) + _, err := db.Exec(d.loadQuery, nil, func(stmt *sql.Statement) bool { id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID stmt.ColumnBytes(0, id[:]) @@ -51,11 +52,11 @@ func (d *DBItemStore) load() error { return err } -func (d *DBItemStore) EnsureLoaded() error { +func (d *DBItemStore) EnsureLoaded(ctx context.Context) error { d.loadMtx.Lock() defer d.loadMtx.Unlock() if !d.loaded { - if err := d.load(); err != nil { + if err := d.load(ctx); err != nil { return err } d.loaded = true @@ -65,10 +66,10 @@ func (d *DBItemStore) EnsureLoaded() error { // Add implements hashsync.ItemStore. 
func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { - if err := d.EnsureLoaded(); err != nil { + if err := d.EnsureLoaded(ctx); err != nil { return err } - has, err := d.Has(k) // TODO: this check shouldn't be needed + has, err := d.Has(ctx, k) // TODO: this check shouldn't be needed if has || err != nil { return err } @@ -77,14 +78,15 @@ func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { // GetRangeInfo implements hashsync.ItemStore. func (d *DBItemStore) GetRangeInfo( + ctx context.Context, preceding hashsync.Iterator, x, y hashsync.Ordered, count int, ) (hashsync.RangeInfo, error) { - if err := d.EnsureLoaded(); err != nil { + if err := d.EnsureLoaded(ctx); err != nil { return hashsync.RangeInfo{}, err } - fpr, err := d.ft.fingerprintInterval(x.(KeyBytes), y.(KeyBytes), count) + fpr, err := d.ft.fingerprintInterval(ctx, x.(KeyBytes), y.(KeyBytes), count) if err != nil { return hashsync.RangeInfo{}, err } @@ -97,6 +99,7 @@ func (d *DBItemStore) GetRangeInfo( } func (d *DBItemStore) SplitRange( + ctx context.Context, preceding hashsync.Iterator, x, y hashsync.Ordered, count int, @@ -108,11 +111,11 @@ func (d *DBItemStore) SplitRange( panic("BUG: bad split count") } - if err := d.EnsureLoaded(); err != nil { + if err := d.EnsureLoaded(ctx); err != nil { return hashsync.SplitInfo{}, err } - sr, err := d.ft.easySplit(x.(KeyBytes), y.(KeyBytes), count) + sr, err := d.ft.easySplit(ctx, x.(KeyBytes), y.(KeyBytes), count) if err == nil { // fmt.Fprintf(os.Stderr, "QQQQQ: fast split, middle: %s\n", sr.middle.String()) return hashsync.SplitInfo{ @@ -139,7 +142,7 @@ func (d *DBItemStore) SplitRange( } // fmt.Fprintf(os.Stderr, "QQQQQ: slow split x %s y %s\n", x.(fmt.Stringer), y.(fmt.Stringer)) - part0, err := d.GetRangeInfo(preceding, x, y, count) + part0, err := d.GetRangeInfo(ctx, preceding, x, y, count) if err != nil { return hashsync.SplitInfo{}, err } @@ -150,7 +153,7 @@ func (d *DBItemStore) SplitRange( if err != nil { 
return hashsync.SplitInfo{}, err } - part1, err := d.GetRangeInfo(part0.End.Clone(), middle, y, -1) + part1, err := d.GetRangeInfo(ctx, part0.End.Clone(), middle, y, -1) if err != nil { return hashsync.SplitInfo{}, err } @@ -161,14 +164,14 @@ func (d *DBItemStore) SplitRange( } // Min implements hashsync.ItemStore. -func (d *DBItemStore) Min() (hashsync.Iterator, error) { - if err := d.EnsureLoaded(); err != nil { +func (d *DBItemStore) Min(ctx context.Context) (hashsync.Iterator, error) { + if err := d.EnsureLoaded(ctx); err != nil { return nil, err } if d.ft.count() == 0 { return nil, nil } - it := d.ft.start() + it := d.ft.start(ctx) if _, err := it.Key(); err != nil { return nil, err } @@ -177,7 +180,10 @@ func (d *DBItemStore) Min() (hashsync.Iterator, error) { // Copy implements hashsync.ItemStore. func (d *DBItemStore) Copy() hashsync.ItemStore { - d.EnsureLoaded() + if !d.loaded { + // FIXME + panic("BUG: can't copy DBItemStore before it's loaded") + } return &DBItemStore{ db: d.db, ft: d.ft.clone(), @@ -190,8 +196,8 @@ func (d *DBItemStore) Copy() hashsync.ItemStore { } // Has implements hashsync.ItemStore. -func (d *DBItemStore) Has(k hashsync.Ordered) (bool, error) { - if err := d.EnsureLoaded(); err != nil { +func (d *DBItemStore) Has(ctx context.Context, k hashsync.Ordered) (bool, error) { + if err := d.EnsureLoaded(ctx); err != nil { return false, err } if d.ft.count() == 0 { @@ -199,7 +205,7 @@ func (d *DBItemStore) Has(k hashsync.Ordered) (bool, error) { } // TODO: should often be able to avoid querying the database if we check the key // against the fptree - it := d.ft.iter(k.(KeyBytes)) + it := d.ft.iter(ctx, k.(KeyBytes)) itK, err := it.Key() if err != nil { return false, err @@ -237,10 +243,16 @@ func (a *ItemStoreAdapter) Copy() hashsync.ItemStore { } // GetRangeInfo implements hashsync.ItemStore. 
-func (a *ItemStoreAdapter) GetRangeInfo(preceding hashsync.Iterator, x hashsync.Ordered, y hashsync.Ordered, count int) (hashsync.RangeInfo, error) { +func (a *ItemStoreAdapter) GetRangeInfo( + ctx context.Context, + preceding hashsync.Iterator, + x hashsync.Ordered, + y hashsync.Ordered, + count int, +) (hashsync.RangeInfo, error) { hx := x.(types.Hash32) hy := y.(types.Hash32) - info, err := a.s.GetRangeInfo(preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) + info, err := a.s.GetRangeInfo(ctx, preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) if err != nil { return hashsync.RangeInfo{}, err } @@ -256,6 +268,7 @@ func (a *ItemStoreAdapter) GetRangeInfo(preceding hashsync.Iterator, x hashsync. } func (a *ItemStoreAdapter) SplitRange( + ctx context.Context, preceding hashsync.Iterator, x hashsync.Ordered, y hashsync.Ordered, @@ -267,7 +280,7 @@ func (a *ItemStoreAdapter) SplitRange( hx := x.(types.Hash32) hy := y.(types.Hash32) si, err := a.s.SplitRange( - preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) + ctx, preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) if err != nil { return hashsync.SplitInfo{}, err } @@ -298,14 +311,14 @@ func (a *ItemStoreAdapter) SplitRange( } // Has implements hashsync.ItemStore. -func (a *ItemStoreAdapter) Has(k hashsync.Ordered) (bool, error) { +func (a *ItemStoreAdapter) Has(ctx context.Context, k hashsync.Ordered) (bool, error) { h := k.(types.Hash32) - return a.s.Has(KeyBytes(h[:])) + return a.s.Has(ctx, KeyBytes(h[:])) } // Min implements hashsync.ItemStore. 
-func (a *ItemStoreAdapter) Min() (hashsync.Iterator, error) { - it, err := a.s.Min() +func (a *ItemStoreAdapter) Min(ctx context.Context) (hashsync.Iterator, error) { + it, err := a.s.Min(ctx) if err != nil { return nil, err } diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go index 8af95aa786..ecb6bae5f0 100644 --- a/sync2/dbsync/dbitemstore_test.go +++ b/sync2/dbsync/dbitemstore_test.go @@ -12,11 +12,12 @@ import ( func TestDBItemStoreEmpty(t *testing.T) { db := populateDB(t, 32, nil) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) - it, err := s.Min() + ctx := context.Background() + it, err := s.Min(ctx) require.NoError(t, err) require.Nil(t, it) - info, err := s.GetRangeInfo(nil, + info, err := s.GetRangeInfo(ctx, nil, KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), -1) @@ -26,7 +27,7 @@ func TestDBItemStoreEmpty(t *testing.T) { require.Nil(t, info.Start) require.Nil(t, info.End) - info, err = s.GetRangeInfo(nil, + info, err = s.GetRangeInfo(ctx, nil, KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), KeyBytes(util.FromHex("9999000000000000000000000000000000000000000000000000000000000000")), -1) @@ -45,13 +46,14 @@ func TestDBItemStore(t *testing.T) { util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), } + ctx := context.Background() db := populateDB(t, 32, ids) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) - it, err := s.Min() + it, err := s.Min(ctx) require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", itKey(t, it).String()) - has, err := s.Has(KeyBytes(util.FromHex("9876000000000000000000000000000000000000000000000000000000000000"))) + has, err := 
s.Has(ctx, KeyBytes(util.FromHex("9876000000000000000000000000000000000000000000000000000000000000"))) require.NoError(t, err) require.False(t, has) @@ -111,16 +113,16 @@ func TestDBItemStore(t *testing.T) { name := fmt.Sprintf("%d-%d_%d", tc.xIdx, tc.yIdx, tc.limit) t.Run(name, func(t *testing.T) { t.Logf("x %s y %s limit %d", ids[tc.xIdx], ids[tc.yIdx], tc.limit) - info, err := s.GetRangeInfo(nil, ids[tc.xIdx], ids[tc.yIdx], tc.limit) + info, err := s.GetRangeInfo(ctx, nil, ids[tc.xIdx], ids[tc.yIdx], tc.limit) require.NoError(t, err) require.Equal(t, tc.count, info.Count) require.Equal(t, tc.fp, info.Fingerprint.(fmt.Stringer).String()) require.Equal(t, ids[tc.startIdx], itKey(t, info.Start)) require.Equal(t, ids[tc.endIdx], itKey(t, info.End)) - has, err := s.Has(ids[tc.startIdx]) + has, err := s.Has(ctx, ids[tc.startIdx]) require.NoError(t, err) require.True(t, has) - has, err = s.Has(ids[tc.endIdx]) + has, err = s.Has(ctx, ids[tc.endIdx]) require.NoError(t, err) require.True(t, has) }) @@ -136,7 +138,8 @@ func TestDBItemStoreAdd(t *testing.T) { } db := populateDB(t, 32, ids) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) - it, err := s.Min() + ctx := context.Background() + it, err := s.Min(ctx) require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", itKey(t, it).String()) @@ -150,7 +153,7 @@ func TestDBItemStoreAdd(t *testing.T) { // s.ft.dump(&sb) // t.Logf("tree:\n%s", sb.String()) - info, err := s.GetRangeInfo(nil, ids[2], ids[0], -1) + info, err := s.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) require.NoError(t, err) require.Equal(t, 3, info.Count) require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) @@ -167,14 +170,15 @@ func TestDBItemStoreCopy(t *testing.T) { } db := populateDB(t, 32, ids) s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) - it, err := s.Min() + ctx := context.Background() + it, err := s.Min(ctx) require.NoError(t, err) 
require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", itKey(t, it).String()) copy := s.Copy() - info, err := copy.GetRangeInfo(nil, ids[2], ids[0], -1) + info, err := copy.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) require.NoError(t, err) require.Equal(t, 2, info.Count) require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) @@ -184,14 +188,14 @@ func TestDBItemStoreCopy(t *testing.T) { newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) require.NoError(t, copy.Add(context.Background(), newID)) - info, err = s.GetRangeInfo(nil, ids[2], ids[0], -1) + info, err = s.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) require.NoError(t, err) require.Equal(t, 2, info.Count) require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) require.Equal(t, ids[2], itKey(t, info.Start)) require.Equal(t, ids[0], itKey(t, info.End)) - info, err = copy.GetRangeInfo(nil, ids[2], ids[0], -1) + info, err = copy.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) require.NoError(t, err) require.Equal(t, 3, info.Count) require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index e871295a79..2099fe3820 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -73,7 +73,7 @@ func newLRU() *lru { } type dbRangeIterator struct { - db sql.Database + db sql.Executor from KeyBytes query string chunkSize int @@ -91,7 +91,7 @@ var _ hashsync.Iterator = &dbRangeIterator{} // makeDBIterator creates a dbRangeIterator and initializes it from the database. // If query returns no rows even after starting from zero ID, errEmptySet error is returned. 
func newDBRangeIterator( - db sql.Database, + db sql.Executor, query string, from KeyBytes, maxChunkSize int, @@ -103,8 +103,6 @@ func newDBRangeIterator( if maxChunkSize <= 0 { panic("BUG: makeDBIterator: chunkSize must be > 0") } - // panic("TBD: QQQQQ: do not preload the iterator! Key should panic upon no entries. With from > max item, iterator should work, wrapping around (TEST)!") - // panic("TBD: QQQQQ: Key() should return an error!") return &dbRangeIterator{ db: db, from: from.Clone(), diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 16107adf71..1f5c06bd33 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -43,15 +43,19 @@ func TestIncID(t *testing.T) { func createDB(t *testing.T, keyLen int) sql.Database { // QQQQQ: FIXME - tmpDir := t.TempDir() - t.Logf("QQQQQ: temp dir: %s", tmpDir) - db, err := sql.Open(fmt.Sprintf("file:%s/test.db", tmpDir), sql.WithIgnoreSchemaDrift()) - require.NoError(t, err) - // db := sql.InMemory(sql.WithIgnoreSchemaDrift()) + // tmpDir := t.TempDir() + // t.Logf("QQQQQ: temp dir: %s", tmpDir) + // db, err := sql.Open( + // fmt.Sprintf("file:%s/test.db", tmpDir), + // sql.WithIgnoreSchemaDrift(), + // sql.WithConnections(16), + // ) + // require.NoError(t, err) + db := sql.InMemory(sql.WithIgnoreSchemaDrift(), sql.WithConnections(16)) t.Cleanup(func() { require.NoError(t, db.Close()) }) - _, err = db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) + _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) require.NoError(t, err) return db } diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 866db9a34b..2b98f6e8c6 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -2,6 +2,7 @@ package dbsync import ( "bytes" + "context" "encoding/binary" "encoding/hex" "errors" @@ -320,6 +321,7 @@ type fpResult struct { } type aggContext struct { + ctx context.Context x, y 
KeyBytes fp, fp0 fingerprint count, count0 uint32 @@ -456,8 +458,8 @@ func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { type idStore interface { clone() idStore registerHash(h KeyBytes) error - start() hashsync.Iterator - iter(from KeyBytes) hashsync.Iterator + start(ctx context.Context) hashsync.Iterator + iter(ctx context.Context, from KeyBytes) hashsync.Iterator } type fpTree struct { @@ -669,7 +671,7 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b startFrom = x } ft.log("aggregateEdge: startFrom %s", startFrom) - it := ft.iter(startFrom) + it := ft.iter(ac.ctx, startFrom) if ac.limit == 0 { ac.end = it.Clone() if x != nil { @@ -1022,19 +1024,19 @@ func (ft *fpTree) aggregateInterval(ac *aggContext) (err error) { } } -func (ft *fpTree) endIterFromPrefix(p prefix) hashsync.Iterator { +func (ft *fpTree) endIterFromPrefix(ac *aggContext, p prefix) hashsync.Iterator { k := make(KeyBytes, ft.keyLen) p.idAfter(k) ft.log("endIterFromPrefix: p: %s idAfter: %s", p, k) - return ft.iter(k) + return ft.iter(ac.ctx, k) } -func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, err error) { +func (ft *fpTree) fingerprintInterval(ctx context.Context, x, y KeyBytes, limit int) (fpr fpResult, err error) { ft.enter("fingerprintInterval: x %s y %s limit %d", x, y, limit) defer func() { ft.leave(fpr.fp, fpr.count, fpr.itype, fpr.start, fpr.end, err) }() - ac := aggContext{x: x, y: y, limit: limit} + ac := aggContext{ctx: ctx, x: x, y: y, limit: limit} if err := ft.aggregateInterval(&ac); err != nil { return fpResult{}, err } @@ -1052,7 +1054,7 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) (fpr fpResult, e ft.log("fingerprintInterval: start %s", ac.start) fpr.start = ac.start } else { - fpr.start = ft.iter(x) + fpr.start = ft.iter(ac.ctx, x) ft.log("fingerprintInterval: start from x: %s", fpr.start) } @@ -1063,10 +1065,10 @@ func (ft *fpTree) fingerprintInterval(x, y KeyBytes, limit int) 
(fpr fpResult, e fpr.end = fpr.start ft.log("fingerprintInterval: end at start %s", fpr.end) } else if ac.lastPrefix != nil { - fpr.end = ft.endIterFromPrefix(*ac.lastPrefix) + fpr.end = ft.endIterFromPrefix(&ac, *ac.lastPrefix) ft.log("fingerprintInterval: end at lastPrefix %s -> %s", *ac.lastPrefix, fpr.end) } else { - fpr.end = ft.iter(y) + fpr.end = ft.iter(ac.ctx, y) ft.log("fingerprintInterval: end at y: %s", fpr.end) } @@ -1082,7 +1084,7 @@ type splitResult struct { // part has close to limit items while not making any idStore queries so that the database // is not accessed. If the split can't be done, which includes the situation where one of // the sides has 0 items, easySplit returns errEasySplitFailed error -func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (sr splitResult, err error) { +func (ft *fpTree) easySplit(ctx context.Context, x, y KeyBytes, limit int) (sr splitResult, err error) { ft.enter("easySplit: x %s y %s limit %d", x, y, limit) defer func() { ft.leave(sr.part0.fp, sr.part0.count, sr.part0.itype, sr.part0.start, sr.part0.end, @@ -1091,7 +1093,7 @@ func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (sr splitResult, err error if limit < 0 { panic("BUG: easySplit with limit < 0") } - ac := aggContext{x: x, y: y, limit: limit, easySplit: true} + ac := aggContext{ctx: ctx, x: x, y: y, limit: limit, easySplit: true} if err := ft.aggregateInterval(&ac); err != nil { return splitResult{}, err } @@ -1121,15 +1123,15 @@ func (ft *fpTree) easySplit(x, y KeyBytes, limit int) (sr splitResult, err error fp: ac.fp0, count: ac.count0, itype: ac.itype, - start: ft.iter(x), - end: ft.endIterFromPrefix(*ac.lastPrefix0), + start: ft.iter(ac.ctx, x), + end: ft.endIterFromPrefix(&ac, *ac.lastPrefix0), } part1 := fpResult{ fp: ac.fp, count: ac.count, itype: ac.itype, start: part0.end.Clone(), - end: ft.endIterFromPrefix(*ac.lastPrefix), + end: ft.endIterFromPrefix(&ac, *ac.lastPrefix), } return splitResult{ part0: part0, diff --git 
a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 0a7e4cbc21..36c02e6265 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -1,6 +1,7 @@ package dbsync import ( + "context" "encoding/binary" "fmt" "math" @@ -728,7 +729,10 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { name = fmt.Sprintf("%d-%d_%d", rtc.xIdx, rtc.yIdx, rtc.limit) } t.Run(name, func(t *testing.T) { - fpr, err := ft.fingerprintInterval(x[:], y[:], rtc.limit) + fpr, err := ft.fingerprintInterval( + context.Background(), + x[:], y[:], rtc.limit, + ) require.NoError(t, err) assert.Equal(t, rtc.fp, fpr.fp.String(), "fp") assert.Equal(t, rtc.count, fpr.count, "count") @@ -786,12 +790,12 @@ func (noIDStore) registerHash(h KeyBytes) error { return nil } -func (noIDStore) start() hashsync.Iterator { +func (noIDStore) start(ctx context.Context) hashsync.Iterator { panic("no ID store") } -func (noIDStore) iter(from KeyBytes) hashsync.Iterator { +func (noIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { return noIter{} } @@ -856,7 +860,7 @@ func TestFPTreeNoIDStore(t *testing.T) { count: 4, }, } { - fpr, err := ft.fingerprintInterval(tc.x, tc.y, tc.limit) + fpr, err := ft.fingerprintInterval(context.Background(), tc.x, tc.y, tc.limit) require.NoError(t, err) require.Equal(t, tc.fp, fpr.fp.String(), "fp") require.Equal(t, tc.count, fpr.count, "count") @@ -874,7 +878,8 @@ func TestFPTreeClone(t *testing.T) { ft1.addHash(hashes[0][:]) ft1.addHash(hashes[1][:]) - fpr, err := ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + ctx := context.Background() + fpr, err := ft1.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) require.NoError(t, err) require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") @@ -895,7 +900,7 @@ func TestFPTreeClone(t *testing.T) { t.Logf("ft2 after-clone:\n%s", sb.String()) // original tree unchanged --- rmme!!!! 
- fpr, err = ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + fpr, err = ft1.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) require.NoError(t, err) require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") @@ -903,14 +908,14 @@ func TestFPTreeClone(t *testing.T) { ft2.addHash(hashes[2][:]) - fpr, err = ft2.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + fpr, err = ft2.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) require.NoError(t, err) require.Equal(t, "666666666666666666666666", fpr.fp.String(), "fp") require.Equal(t, uint32(3), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") // original tree unchanged - fpr, err = ft1.fingerprintInterval(hashes[0][:], hashes[0][:], -1) + fpr, err = ft1.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) require.NoError(t, err) require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") @@ -1091,7 +1096,7 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, limit int) fpResult { expFPR := dumbFP(hs, x, y, limit) - fpr, err := ft.fingerprintInterval(x[:], y[:], limit) + fpr, err := ft.fingerprintInterval(context.Background(), x[:], y[:], limit) require.NoError(t, err) require.Equal(t, expFPR, toFPResultWithBounds(t, fpr), "x=%s y=%s limit=%d", x.String(), y.String(), limit) @@ -1149,7 +1154,7 @@ func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, checkTree(t, ft, maxDepth) - fpr, err := ft.fingerprintInterval(hs[0][:], hs[0][:], -1) + fpr, err := ft.fingerprintInterval(context.Background(), hs[0][:], hs[0][:], -1) require.NoError(t, err) require.Equal(t, fp, fpr.fp, "fp") require.Equal(t, uint32(numItems), fpr.count, "count") @@ -1223,7 +1228,7 @@ func verifyEasySplit( ) { // t.Logf("depth %d", depth) // t.Logf("--- fingerprint interval 
%s %s ---", x.String(), y.String()) - fpr, err := ft.fingerprintInterval(x, y, -1) + fpr, err := ft.fingerprintInterval(context.Background(), x, y, -1) require.NoError(t, err) if fpr.count <= 1 { return @@ -1235,7 +1240,7 @@ func verifyEasySplit( m := fpr.count / 2 // t.Logf("--- easy split %s %s %d ---", x.String(), y.String(), m) - sr, err := ft.easySplit(x[:], y[:], int(m)) + sr, err := ft.easySplit(context.Background(), x[:], y[:], int(m)) if err != nil { require.ErrorIs(t, err, errEasySplitFailed) return 0, 1 @@ -1254,28 +1259,28 @@ func verifyEasySplit( precMiddle := itKey(t, sr.part0.end) require.Equal(t, precMiddle, itKey(t, sr.part1.start)) - fpr11, err := ft.fingerprintInterval(x, precMiddle, -1) + fpr11, err := ft.fingerprintInterval(context.Background(), x, precMiddle, -1) require.NoError(t, err) require.Equal(t, sr.part0.fp, fpr11.fp) require.Equal(t, sr.part0.count, fpr11.count) require.Equal(t, a, itKey(t, fpr11.start)) require.Equal(t, precMiddle, itKey(t, fpr11.end)) - fpr12, err := ft.fingerprintInterval(precMiddle, y, -1) + fpr12, err := ft.fingerprintInterval(context.Background(), precMiddle, y, -1) require.NoError(t, err) require.Equal(t, sr.part1.fp, fpr12.fp) require.Equal(t, sr.part1.count, fpr12.count) require.Equal(t, precMiddle, itKey(t, fpr12.start)) require.Equal(t, b, itKey(t, fpr12.end)) - fpr11, err = ft.fingerprintInterval(x, sr.middle, -1) + fpr11, err = ft.fingerprintInterval(context.Background(), x, sr.middle, -1) require.NoError(t, err) require.Equal(t, sr.part0.fp, fpr11.fp) require.Equal(t, sr.part0.count, fpr11.count) require.Equal(t, a, itKey(t, fpr11.start)) require.Equal(t, precMiddle, itKey(t, fpr11.end)) - fpr12, err = ft.fingerprintInterval(sr.middle, y, -1) + fpr12, err = ft.fingerprintInterval(context.Background(), sr.middle, y, -1) require.NoError(t, err) require.Equal(t, sr.part1.fp, fpr12.fp) require.Equal(t, sr.part1.count, fpr12.count) @@ -1301,7 +1306,7 @@ func TestEasySplit(t *testing.T) { // t.Logf("adding 
hash %s", h.String()) ft.addHash(h[:]) } - k, err := ft.start().Key() + k, err := ft.start(context.Background()).Key() require.NoError(t, err) x := k.(KeyBytes) v := load64(x) & ^(1<<(64-maxDepth) - 1) @@ -1460,7 +1465,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { for n := 0; n < numIter; n++ { x := types.RandomHash() y := types.RandomHash() - ft.fingerprintInterval(x[:], y[:], -1) + ft.fingerprintInterval(context.Background(), x[:], y[:], -1) } elapsed := time.Now().Sub(ts) @@ -1486,7 +1491,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { // t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) expFPResult := dumbFP(*hs, x, y, -1) //expFPResult := dumbAggATXs(t, db, x, y) - fpr, err := ft.fingerprintInterval(x[:], y[:], -1) + fpr, err := ft.fingerprintInterval(context.Background(), x[:], y[:], -1) require.NoError(t, err) require.Equal(t, expFPResult, toFPResultWithBounds(t, fpr), "x=%s y=%s", x.String(), y.String()) @@ -1497,7 +1502,7 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { } // t.Logf("QQQQQ: x=%s y=%s limit=%d", x.String(), y.String(), limit) expFPResult = dumbFP(*hs, x, y, limit) - fpr, err = ft.fingerprintInterval(x[:], y[:], limit) + fpr, err = ft.fingerprintInterval(context.Background(), x[:], y[:], limit) require.NoError(t, err) require.Equal(t, expFPResult, toFPResultWithBounds(t, fpr), "x=%s y=%s limit=%d", x.String(), y.String(), limit) diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index 5694cfab19..e6f420145a 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -1,6 +1,8 @@ package dbsync import ( + "context" + "github.com/spacemeshos/go-spacemesh/sync2/hashsync" "github.com/spacemeshos/go-spacemesh/sync2/internal/skiplist" ) @@ -34,11 +36,11 @@ func (s *inMemIDStore) registerHash(h KeyBytes) error { return nil } -func (s *inMemIDStore) start() hashsync.Iterator { +func (s *inMemIDStore) start(ctx context.Context) hashsync.Iterator 
{ return &inMemIDStoreIterator{sl: s.sl, node: s.sl.First()} } -func (s *inMemIDStore) iter(from KeyBytes) hashsync.Iterator { +func (s *inMemIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { node := s.sl.FindGTENode(from) if node == nil { node = s.sl.First() diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go index 0a74582d3b..337dccf2c7 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -1,6 +1,7 @@ package dbsync import ( + "context" "encoding/hex" "testing" @@ -17,11 +18,12 @@ func TestInMemIDStore(t *testing.T) { err error ) s := newInMemIDStore(32) + ctx := context.Background() - _, err = s.start().Key() + _, err = s.start(ctx).Key() require.ErrorIs(t, err, errEmptySet) - _, err = s.iter(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")).Key() + _, err = s.iter(ctx, util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")).Key() require.ErrorIs(t, err, errEmptySet) for _, h := range []string{ @@ -38,9 +40,10 @@ func TestInMemIDStore(t *testing.T) { for i := range 6 { if i%2 == 0 { - it = s.start() + it = s.start(ctx) } else { it = s.iter( + ctx, util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) } var items []string @@ -67,6 +70,7 @@ func TestInMemIDStore(t *testing.T) { s1.registerHash(h[:]) items = nil it = s1.iter( + ctx, util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) for range 8 { items = append(items, hex.EncodeToString(itKey(t, it))) @@ -87,6 +91,7 @@ func TestInMemIDStore(t *testing.T) { hex.EncodeToString(itKey(t, it))) it = s1.iter( + ctx, util.FromHex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0")) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index 869e8c1c36..26cc9315dc 100644 --- a/sync2/dbsync/p2p_test.go +++ 
b/sync2/dbsync/p2p_test.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/common/util" "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) @@ -32,9 +33,41 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { require.NoError(t, err) proto := "itest" t.Logf("QQQQQ: 2") + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) storeA := NewItemStoreAdapter(NewDBItemStore(dbA, "select id from foo", testQuery, 32, maxDepth)) + t.Logf("QQQQQ: 2.1") + require.NoError(t, dbA.WithTx(ctx, func(tx sql.Transaction) error { + return storeA.s.EnsureLoaded(WithSQLExec(ctx, tx)) + })) t.Logf("QQQQQ: 3") storeB := NewItemStoreAdapter(NewDBItemStore(dbB, "select id from foo", testQuery, 32, maxDepth)) + t.Logf("QQQQQ: 3.1") + var x *types.Hash32 + require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { + ctx := WithSQLExec(ctx, tx) + if err := storeB.s.EnsureLoaded(ctx); err != nil { + return err + } + it, err := storeB.Min(ctx) + if err != nil { + return err + } + if it != nil { + x = &types.Hash32{} + k, err := it.Key() + if err != nil { + return err + } + h := k.(types.Hash32) + v := load64(h[:]) & ^uint64(1<<(64-maxDepth)-1) + binary.BigEndian.PutUint64(x[:], v) + for i := 8; i < len(x); i++ { + x[i] = 0 + } + t.Logf("x: %s", x.String()) + } + return nil + })) t.Logf("QQQQQ: 4") // QQQQQ: rmme @@ -56,12 +89,13 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { // uncomment to enable verbose logging which may slow down tests // hashsync.WithRangeSyncLogger(log.Named("sideA")), }) - return pss.Serve(ctx, req, stream, storeA) + return dbA.WithTx(ctx, func(tx sql.Transaction) error { + return pss.Serve(WithSQLExec(ctx, tx), req, stream, storeA) + }) }, server.WithTimeout(10*time.Second), server.WithLog(log)) - ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) var eg errgroup.Group client := server.New(mesh.Hosts()[1], proto, @@ -98,24 +132,10 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { // hashsync.WithRangeSyncLogger(log.Named("sideB")), }) - var x *types.Hash32 - it, err := storeB.Min() - require.NoError(t, err) - if it != nil { - x = &types.Hash32{} - k, err := it.Key() - require.NoError(t, err) - h := k.(types.Hash32) - v := load64(h[:]) & ^uint64(1<<(64-maxDepth)-1) - binary.BigEndian.PutUint64(x[:], v) - for i := 8; i < len(x); i++ { - x[i] = 0 - } - t.Logf("x: %s", x.String()) - } - tStart := time.Now() - require.NoError(t, pss.SyncStore(ctx, srvPeerID, storeB, x, x)) + require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { + return pss.SyncStore(WithSQLExec(ctx, tx), srvPeerID, storeB, x, x) + })) t.Logf("synced in %v", time.Since(tStart)) // // QQQQQ: rmme @@ -129,47 +149,53 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { if len(combinedItems) == 0 { return } - it, err = storeA.Min() - require.NoError(t, err) var actItemsA []KeyBytes - if len(combinedItems) == 0 { - assert.Nil(t, it) - } else { - for range combinedItems { + require.NoError(t, dbA.WithTx(ctx, func(tx sql.Transaction) error { + it, err := storeA.Min(WithSQLExec(ctx, tx)) + require.NoError(t, err) + if len(combinedItems) == 0 { + assert.Nil(t, it) + } else { + for range combinedItems { + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) + // t.Logf("synced itemA: %s", h.String()) + actItemsA = append(actItemsA, h[:]) + require.NoError(t, it.Next()) + } k, err := it.Key() require.NoError(t, err) h := k.(types.Hash32) - // t.Logf("synced itemA: %s", h.String()) - actItemsA = append(actItemsA, h[:]) - require.NoError(t, it.Next()) + assert.Equal(t, actItemsA[0], KeyBytes(h[:])) } - k, err := it.Key() - require.NoError(t, err) - h := k.(types.Hash32) - assert.Equal(t, actItemsA[0], KeyBytes(h[:])) - } + 
return nil + })) - it, err = storeB.Min() - require.NoError(t, err) var actItemsB []KeyBytes - if len(combinedItems) == 0 { - assert.Nil(t, it) - } else { - for range combinedItems { + require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { + it, err := storeB.Min(WithSQLExec(ctx, tx)) + require.NoError(t, err) + if len(combinedItems) == 0 { + assert.Nil(t, it) + } else { + for range combinedItems { + k, err := it.Key() + require.NoError(t, err) + h := k.(types.Hash32) + // t.Logf("synced itemB: %s", h.String()) + actItemsB = append(actItemsB, h[:]) + require.NoError(t, it.Next()) + } k, err := it.Key() require.NoError(t, err) h := k.(types.Hash32) - // t.Logf("synced itemB: %s", h.String()) - actItemsB = append(actItemsB, h[:]) - require.NoError(t, it.Next()) + assert.Equal(t, actItemsB[0], KeyBytes(h[:])) } - k, err := it.Key() - require.NoError(t, err) - h := k.(types.Hash32) - assert.Equal(t, actItemsB[0], KeyBytes(h[:])) - assert.Equal(t, combinedItems, actItemsA) - assert.Equal(t, actItemsA, actItemsB) - } + return nil + })) + assert.Equal(t, combinedItems, actItemsA) + assert.Equal(t, actItemsA, actItemsB) } func TestP2P(t *testing.T) { @@ -239,7 +265,10 @@ func TestP2P(t *testing.T) { }) t.Run("random test", func(t *testing.T) { // TODO: increase these values and profile - const nShared = 800000 + // const nShared = 8000000 + // const nUniqueA = 40000 + // const nUniqueB = 80000 + const nShared = 80000 const nUniqueA = 400 const nUniqueB = 800 // const nShared = 2 diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index b6790e57fd..1c90aa8743 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -2,6 +2,7 @@ package dbsync import ( "bytes" + "context" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" @@ -9,6 +10,26 @@ import ( const sqlMaxChunkSize = 1024 +type dbExecKey struct{} + +func WithSQLExec(ctx context.Context, db sql.Executor) context.Context { + 
return context.WithValue(ctx, dbExecKey{}, db) +} + +func ContextSQLExec(ctx context.Context, db sql.Database) sql.Executor { + v := ctx.Value(dbExecKey{}) + if v == nil { + return db + } + return v.(sql.Executor) +} + +func WithSQLTransaction(ctx context.Context, db sql.Database, toCall func(context.Context) error) error { + return db.WithTx(ctx, func(tx sql.Transaction) error { + return toCall(WithSQLExec(ctx, tx)) + }) +} + type sqlIDStore struct { db sql.Database query string @@ -36,16 +57,16 @@ func (s *sqlIDStore) registerHash(h KeyBytes) error { return nil } -func (s *sqlIDStore) start() hashsync.Iterator { +func (s *sqlIDStore) start(ctx context.Context) hashsync.Iterator { // TODO: should probably use a different query to get the first key - return s.iter(make(KeyBytes, s.keyLen)) + return s.iter(ctx, make(KeyBytes, s.keyLen)) } -func (s *sqlIDStore) iter(from KeyBytes) hashsync.Iterator { +func (s *sqlIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { if len(from) != s.keyLen { panic("BUG: invalid key length") } - return newDBRangeIterator(s.db, s.query, from, sqlMaxChunkSize, s.cache) + return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.query, from, sqlMaxChunkSize, s.cache) } type dbBackedStore struct { @@ -73,15 +94,15 @@ func (s *dbBackedStore) registerHash(h KeyBytes) error { return s.inMemIDStore.registerHash(h) } -func (s *dbBackedStore) start() hashsync.Iterator { - dbIt := s.sqlIDStore.start() - memIt := s.inMemIDStore.start() +func (s *dbBackedStore) start(ctx context.Context) hashsync.Iterator { + dbIt := s.sqlIDStore.start(ctx) + memIt := s.inMemIDStore.start(ctx) return combineIterators(nil, dbIt, memIt) } -func (s *dbBackedStore) iter(from KeyBytes) hashsync.Iterator { - dbIt := s.sqlIDStore.iter(from) - memIt := s.inMemIDStore.iter(from) +func (s *dbBackedStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { + dbIt := s.sqlIDStore.iter(ctx, from) + memIt := s.inMemIDStore.iter(ctx, from) return 
combineIterators(from, dbIt, memIt) } diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index 02e9169b67..3007747e89 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ b/sync2/dbsync/sqlidstore_test.go @@ -1,6 +1,7 @@ package dbsync import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -14,41 +15,55 @@ func TestDBBackedStore(t *testing.T) { {0, 0, 0, 7, 0, 0, 0, 0}, } db := populateDB(t, 8, initialIDs) - store := newDBBackedStore(db, fakeIDQuery, 8) - var actualIDs []KeyBytes - it := store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) - for range 5 { - actualIDs = append(actualIDs, itKey(t, it)) - require.NoError(t, it.Next()) - } - require.Equal(t, []KeyBytes{ - {0, 0, 0, 1, 0, 0, 0, 0}, - {0, 0, 0, 3, 0, 0, 0, 0}, - {0, 0, 0, 5, 0, 0, 0, 0}, - {0, 0, 0, 7, 0, 0, 0, 0}, - {0, 0, 0, 1, 0, 0, 0, 0}, // wrapped around - }, actualIDs) + verify := func(t *testing.T, ctx context.Context) { + store := newDBBackedStore(db, fakeIDQuery, 8) + it := store.iter(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) + var actualIDs []KeyBytes + for range 5 { + actualIDs = append(actualIDs, itKey(t, it)) + require.NoError(t, it.Next()) + } + require.Equal(t, []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + {0, 0, 0, 1, 0, 0, 0, 0}, // wrapped around + }, actualIDs) - it = store.start() - for n := range 5 { - require.Equal(t, actualIDs[n], itKey(t, it)) - require.NoError(t, it.Next()) - } + it = store.start(ctx) + for n := range 5 { + require.Equal(t, actualIDs[n], itKey(t, it)) + require.NoError(t, it.Next()) + } - require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) - require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) - actualIDs = nil - it = store.iter(KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) - for range 6 { - actualIDs = append(actualIDs, itKey(t, it)) - require.NoError(t, it.Next()) + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 
0, 0, 0, 0})) + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) + actualIDs = nil + it = store.iter(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) + for range 6 { + actualIDs = append(actualIDs, itKey(t, it)) + require.NoError(t, it.Next()) + } + require.Equal(t, []KeyBytes{ + {0, 0, 0, 1, 0, 0, 0, 0}, + {0, 0, 0, 2, 0, 0, 0, 0}, + {0, 0, 0, 3, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + {0, 0, 0, 9, 0, 0, 0, 0}, + }, actualIDs) } - require.Equal(t, []KeyBytes{ - {0, 0, 0, 1, 0, 0, 0, 0}, - {0, 0, 0, 2, 0, 0, 0, 0}, - {0, 0, 0, 3, 0, 0, 0, 0}, - {0, 0, 0, 5, 0, 0, 0, 0}, - {0, 0, 0, 7, 0, 0, 0, 0}, - {0, 0, 0, 9, 0, 0, 0, 0}, - }, actualIDs) + + t.Run("no transaction", func(t *testing.T) { + verify(t, context.Background()) + }) + + t.Run("with transaction", func(t *testing.T) { + err := WithSQLTransaction(context.Background(), db, func(ctx context.Context) error { + verify(t, ctx) + return nil + }) + require.NoError(t, err) + }) } diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index 627e17e8ef..143b2374ca 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -287,12 +287,12 @@ func (pss *PairwiseStoreSyncer) Probe( rsr := NewRangeSetReconciler(is, pss.opts...) 
if x == nil { initReq, err = c.withInitialRequest(func(c Conduit) error { - info, err = rsr.InitiateProbe(c) + info, err = rsr.InitiateProbe(ctx, c) return err }) } else { initReq, err = c.withInitialRequest(func(c Conduit) error { - info, err = rsr.InitiateBoundedProbe(c, *x, *y) + info, err = rsr.InitiateBoundedProbe(ctx, c, *x, *y) return err }) } @@ -325,10 +325,12 @@ func (pss *PairwiseStoreSyncer) SyncStore( err error ) if x == nil { - initReq, err = c.withInitialRequest(rsr.Initiate) + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.Initiate(ctx, c) + }) } else { initReq, err = c.withInitialRequest(func(c Conduit) error { - return rsr.InitiateBounded(c, *x, *y) + return rsr.InitiateBounded(ctx, c, *x, *y) }) } if err != nil { diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index a83d1330f6..8da457e945 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -481,11 +481,11 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { storeA, getRequester, opts, func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { pss := NewPairwiseStoreSyncer(client, opts) - minA, err := storeA.Min() + minA, err := storeA.Min(ctx) require.NoError(t, err) kA, err := minA.Key() require.NoError(t, err) - infoA, err := storeA.GetRangeInfo(nil, kA, kA, -1) + infoA, err := storeA.GetRangeInfo(ctx, nil, kA, kA, -1) require.NoError(t, err) prA, err := pss.Probe(ctx, srvPeerID, storeB, nil, nil) require.NoError(t, err) @@ -493,11 +493,11 @@ func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { require.Equal(t, infoA.Count, prA.Count) require.InDelta(t, 0.98, prA.Sim, 0.05, "sim") - minA, err = storeA.Min() + minA, err = storeA.Min(ctx) require.NoError(t, err) kA, err = minA.Key() require.NoError(t, err) - partInfoA, err := storeA.GetRangeInfo(nil, kA, kA, infoA.Count/2) + partInfoA, err := storeA.GetRangeInfo(ctx, nil, kA, kA, infoA.Count/2) 
require.NoError(t, err) xK, err := partInfoA.Start.Key() require.NoError(t, err) diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index 86cf0a9243..d7445177f9 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -54,17 +54,17 @@ type ItemStore interface { // is returned for the corresponding subrange of the requested range. // If both x and y is nil, the whole set of items is used. // If only x or only y is nil, GetRangeInfo panics - GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) + GetRangeInfo(ctx context.Context, preceding Iterator, x, y Ordered, count int) (RangeInfo, error) // SplitRange splits the range roughly after the specified count of items, // returning RangeInfo for the first half and the second half of the range. - SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) + SplitRange(ctx context.Context, preceding Iterator, x, y Ordered, count int) (SplitInfo, error) // Min returns the iterator pointing at the minimum element // in the store. 
If the store is empty, it returns nil - Min() (Iterator, error) + Min(ctx context.Context) (Iterator, error) // Copy makes a shallow copy of the ItemStore Copy() ItemStore // Has returns true if the specified key is present in ItemStore - Has(k Ordered) (bool, error) + Has(ctx context.Context, k Ordered) (bool, error) } type Requester interface { @@ -73,7 +73,7 @@ type Requester interface { } type SyncBase interface { - Count() (int, error) + Count(ctx context.Context) (int, error) Derive(p p2p.Peer) Syncer Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) Wait() error diff --git a/sync2/hashsync/log.go b/sync2/hashsync/log.go index 8d3f085b06..980c9f1434 100644 --- a/sync2/hashsync/log.go +++ b/sync2/hashsync/log.go @@ -10,15 +10,23 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" ) +type itFormatter struct { + it Iterator +} + +func (f itFormatter) String() string { + k, err := f.it.Key() + if err != nil { + return fmt.Sprintf("", err) + } + return hexStr(k) +} + func IteratorField(name string, it Iterator) zap.Field { if it == nil { return zap.String(name, "") } - k, err := it.Key() - if err != nil { - return zap.String(name, fmt.Sprintf("", err)) - } - return HexField(name, k) + return zap.Stringer(name, itFormatter{it: it}) } // based on code from testify @@ -40,25 +48,29 @@ func isNil(object any) bool { return false } -func HexField(name string, k any) zap.Field { +func hexStr(k any) string { switch h := k.(type) { case types.Hash32: - return zap.String(name, h.ShortString()) + return h.ShortString() case types.Hash12: - return zap.String(name, hex.EncodeToString(h[:5])) + return hex.EncodeToString(h[:5]) case []byte: if len(h) > 5 { h = h[:5] } - return zap.String(name, hex.EncodeToString(h[:5])) + return hex.EncodeToString(h[:5]) case string: - return zap.String(name, h) + return h case fmt.Stringer: - return zap.String(name, h.String()) + return h.String() default: if isNil(k) { - return zap.String(name, "") + return "" } - 
panic("unexpected type: " + reflect.TypeOf(k).String()) + return fmt.Sprintf("", k) } } + +func HexField(name string, k any) zap.Field { + return zap.String(name, hexStr(k)) +} diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go index 30ecc97651..b8a011e764 100644 --- a/sync2/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -258,18 +258,18 @@ func (c *MockItemStoreCopyCall) DoAndReturn(f func() ItemStore) *MockItemStoreCo } // GetRangeInfo mocks base method. -func (m *MockItemStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { +func (m *MockItemStore) GetRangeInfo(ctx context.Context, preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRangeInfo", preceding, x, y, count) + ret := m.ctrl.Call(m, "GetRangeInfo", ctx, preceding, x, y, count) ret0, _ := ret[0].(RangeInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // GetRangeInfo indicates an expected call of GetRangeInfo. 
-func (mr *MockItemStoreMockRecorder) GetRangeInfo(preceding, x, y, count any) *MockItemStoreGetRangeInfoCall { +func (mr *MockItemStoreMockRecorder) GetRangeInfo(ctx, preceding, x, y, count any) *MockItemStoreGetRangeInfoCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeInfo", reflect.TypeOf((*MockItemStore)(nil).GetRangeInfo), preceding, x, y, count) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeInfo", reflect.TypeOf((*MockItemStore)(nil).GetRangeInfo), ctx, preceding, x, y, count) return &MockItemStoreGetRangeInfoCall{Call: call} } @@ -285,30 +285,30 @@ func (c *MockItemStoreGetRangeInfoCall) Return(arg0 RangeInfo, arg1 error) *Mock } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreGetRangeInfoCall) Do(f func(Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { +func (c *MockItemStoreGetRangeInfoCall) Do(f func(context.Context, Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreGetRangeInfoCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { +func (c *MockItemStoreGetRangeInfoCall) DoAndReturn(f func(context.Context, Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { c.Call = c.Call.DoAndReturn(f) return c } // Has mocks base method. -func (m *MockItemStore) Has(k Ordered) (bool, error) { +func (m *MockItemStore) Has(ctx context.Context, k Ordered) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Has", k) + ret := m.ctrl.Call(m, "Has", ctx, k) ret0, _ := ret[0].(bool) ret1, _ := ret[1].(error) return ret0, ret1 } // Has indicates an expected call of Has. 
-func (mr *MockItemStoreMockRecorder) Has(k any) *MockItemStoreHasCall { +func (mr *MockItemStoreMockRecorder) Has(ctx, k any) *MockItemStoreHasCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockItemStore)(nil).Has), k) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockItemStore)(nil).Has), ctx, k) return &MockItemStoreHasCall{Call: call} } @@ -324,30 +324,30 @@ func (c *MockItemStoreHasCall) Return(arg0 bool, arg1 error) *MockItemStoreHasCa } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreHasCall) Do(f func(Ordered) (bool, error)) *MockItemStoreHasCall { +func (c *MockItemStoreHasCall) Do(f func(context.Context, Ordered) (bool, error)) *MockItemStoreHasCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreHasCall) DoAndReturn(f func(Ordered) (bool, error)) *MockItemStoreHasCall { +func (c *MockItemStoreHasCall) DoAndReturn(f func(context.Context, Ordered) (bool, error)) *MockItemStoreHasCall { c.Call = c.Call.DoAndReturn(f) return c } // Min mocks base method. -func (m *MockItemStore) Min() (Iterator, error) { +func (m *MockItemStore) Min(ctx context.Context) (Iterator, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Min") + ret := m.ctrl.Call(m, "Min", ctx) ret0, _ := ret[0].(Iterator) ret1, _ := ret[1].(error) return ret0, ret1 } // Min indicates an expected call of Min. 
-func (mr *MockItemStoreMockRecorder) Min() *MockItemStoreMinCall { +func (mr *MockItemStoreMockRecorder) Min(ctx any) *MockItemStoreMinCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Min", reflect.TypeOf((*MockItemStore)(nil).Min)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Min", reflect.TypeOf((*MockItemStore)(nil).Min), ctx) return &MockItemStoreMinCall{Call: call} } @@ -363,30 +363,30 @@ func (c *MockItemStoreMinCall) Return(arg0 Iterator, arg1 error) *MockItemStoreM } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreMinCall) Do(f func() (Iterator, error)) *MockItemStoreMinCall { +func (c *MockItemStoreMinCall) Do(f func(context.Context) (Iterator, error)) *MockItemStoreMinCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreMinCall) DoAndReturn(f func() (Iterator, error)) *MockItemStoreMinCall { +func (c *MockItemStoreMinCall) DoAndReturn(f func(context.Context) (Iterator, error)) *MockItemStoreMinCall { c.Call = c.Call.DoAndReturn(f) return c } // SplitRange mocks base method. -func (m *MockItemStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { +func (m *MockItemStore) SplitRange(ctx context.Context, preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SplitRange", preceding, x, y, count) + ret := m.ctrl.Call(m, "SplitRange", ctx, preceding, x, y, count) ret0, _ := ret[0].(SplitInfo) ret1, _ := ret[1].(error) return ret0, ret1 } // SplitRange indicates an expected call of SplitRange. 
-func (mr *MockItemStoreMockRecorder) SplitRange(preceding, x, y, count any) *MockItemStoreSplitRangeCall { +func (mr *MockItemStoreMockRecorder) SplitRange(ctx, preceding, x, y, count any) *MockItemStoreSplitRangeCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SplitRange", reflect.TypeOf((*MockItemStore)(nil).SplitRange), preceding, x, y, count) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SplitRange", reflect.TypeOf((*MockItemStore)(nil).SplitRange), ctx, preceding, x, y, count) return &MockItemStoreSplitRangeCall{Call: call} } @@ -402,13 +402,13 @@ func (c *MockItemStoreSplitRangeCall) Return(arg0 SplitInfo, arg1 error) *MockIt } // Do rewrite *gomock.Call.Do -func (c *MockItemStoreSplitRangeCall) Do(f func(Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { +func (c *MockItemStoreSplitRangeCall) Do(f func(context.Context, Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreSplitRangeCall) DoAndReturn(f func(Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { +func (c *MockItemStoreSplitRangeCall) DoAndReturn(f func(context.Context, Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -541,18 +541,18 @@ func (m *MockSyncBase) EXPECT() *MockSyncBaseMockRecorder { } // Count mocks base method. -func (m *MockSyncBase) Count() (int, error) { +func (m *MockSyncBase) Count(ctx context.Context) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Count") + ret := m.ctrl.Call(m, "Count", ctx) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // Count indicates an expected call of Count. 
-func (mr *MockSyncBaseMockRecorder) Count() *MockSyncBaseCountCall { +func (mr *MockSyncBaseMockRecorder) Count(ctx any) *MockSyncBaseCountCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockSyncBase)(nil).Count)) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockSyncBase)(nil).Count), ctx) return &MockSyncBaseCountCall{Call: call} } @@ -568,13 +568,13 @@ func (c *MockSyncBaseCountCall) Return(arg0 int, arg1 error) *MockSyncBaseCountC } // Do rewrite *gomock.Call.Do -func (c *MockSyncBaseCountCall) Do(f func() (int, error)) *MockSyncBaseCountCall { +func (c *MockSyncBaseCountCall) Do(f func(context.Context) (int, error)) *MockSyncBaseCountCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncBaseCountCall) DoAndReturn(f func() (int, error)) *MockSyncBaseCountCall { +func (c *MockSyncBaseCountCall) DoAndReturn(f func(context.Context) (int, error)) *MockSyncBaseCountCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/sync2/hashsync/multipeer.go b/sync2/hashsync/multipeer.go index 3978d1a9ef..8bc31daf3e 100644 --- a/sync2/hashsync/multipeer.go +++ b/sync2/hashsync/multipeer.go @@ -179,7 +179,7 @@ func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p. zap.Int("count", pr.Count)) } - c, err := mpr.syncBase.Count() + c, err := mpr.syncBase.Count(ctx) if err != nil { return s, err } diff --git a/sync2/hashsync/multipeer_test.go b/sync2/hashsync/multipeer_test.go index 271a204476..58fd2ea1fe 100644 --- a/sync2/hashsync/multipeer_test.go +++ b/sync2/hashsync/multipeer_test.go @@ -131,7 +131,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(10) // Advance by peer wait time. 
After that, 6 peers will be selected // randomly and probed - mt.syncBase.EXPECT().Count().Return(50, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(50, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{ FP: "foo", @@ -159,7 +159,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{ FP: "foo", @@ -179,7 +179,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(1) - mt.syncBase.EXPECT().Count().Return(50, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(50, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(1, ProbeResult{ FP: "foo", @@ -199,7 +199,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() mt.syncBase.EXPECT().Probe(gomock.Any(), gomock.Any()). 
Return(ProbeResult{}, errors.New("probe failed")) mt.expectProbe(5, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) @@ -214,7 +214,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 3) @@ -230,7 +230,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 0) @@ -246,7 +246,7 @@ func TestMultiPeerSync(t *testing.T) { mt := newMultiPeerSyncTester(t) ctx := mt.start() mt.addPeers(10) - mt.syncBase.EXPECT().Count().Return(100, nil).AnyTimes() + mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.syncRunner.EXPECT().fullSync(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, peers []p2p.Peer) error { diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 9def50e0a6..7abb3a8824 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -250,7 +250,7 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info Ran return nil } -func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg SyncMessage) (done bool, err error) { +func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, preceding Iterator, msg SyncMessage) (done bool, err error) { rsr.log.Debug("handleMessage", IteratorField("preceding", preceding), zap.String("msg", SyncMessageToString(msg))) x := msg.X() @@ -259,14 
+259,14 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg if msg.Type() == MessageTypeEmptySet || (msg.Type() == MessageTypeProbe && x == nil && y == nil) { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) - it, err := rsr.is.Min() + it, err := rsr.is.Min(ctx) if err != nil { return false, err } if it == nil { // We don't have any items at all, too if msg.Type() == MessageTypeProbe { - info, err := rsr.is.GetRangeInfo(preceding, nil, nil, -1) + info, err := rsr.is.GetRangeInfo(ctx, preceding, nil, nil, -1) if err != nil { return false, err } @@ -288,7 +288,7 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg } else if x == nil || y == nil { return false, errors.New("bad X or Y") } - info, err := rsr.is.GetRangeInfo(preceding, x, y, -1) + info, err := rsr.is.GetRangeInfo(ctx, preceding, x, y, -1) if err != nil { return false, err } @@ -377,7 +377,7 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg rsr.log.Debug("handleMessage: PRE split range", HexField("x", x), HexField("y", y), zap.Int("countArg", count)) - si, err := rsr.is.SplitRange(preceding, x, y, count) + si, err := rsr.is.SplitRange(ctx, preceding, x, y, count) if err != nil { return false, err } @@ -406,8 +406,8 @@ func (rsr *RangeSetReconciler) handleMessage(c Conduit, preceding Iterator, msg return done, nil } -func (rsr *RangeSetReconciler) Initiate(c Conduit) error { - it, err := rsr.is.Min() +func (rsr *RangeSetReconciler) Initiate(ctx context.Context, c Conduit) error { + it, err := rsr.is.Min(ctx) if err != nil { return err } @@ -418,10 +418,10 @@ func (rsr *RangeSetReconciler) Initiate(c Conduit) error { return err } } - return rsr.InitiateBounded(c, x, x) + return rsr.InitiateBounded(ctx, c, x, x) } -func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { +func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x, y Ordered) 
error { rsr.log.Debug("initiate", HexField("x", x), HexField("y", y)) if x == nil { rsr.log.Debug("initiate: send empty set") @@ -429,7 +429,7 @@ func (rsr *RangeSetReconciler) InitiateBounded(c Conduit, x, y Ordered) error { return err } } else { - info, err := rsr.is.GetRangeInfo(nil, x, y, -1) + info, err := rsr.is.GetRangeInfo(ctx, nil, x, y, -1) if err != nil { return err } @@ -478,12 +478,16 @@ func (rsr *RangeSetReconciler) getMessages(c Conduit) (msgs []SyncMessage, done } } -func (rsr *RangeSetReconciler) InitiateProbe(c Conduit) (RangeInfo, error) { - return rsr.InitiateBoundedProbe(c, nil, nil) +func (rsr *RangeSetReconciler) InitiateProbe(ctx context.Context, c Conduit) (RangeInfo, error) { + return rsr.InitiateBoundedProbe(ctx, c, nil, nil) } -func (rsr *RangeSetReconciler) InitiateBoundedProbe(c Conduit, x, y Ordered) (RangeInfo, error) { - info, err := rsr.is.GetRangeInfo(nil, x, y, -1) +func (rsr *RangeSetReconciler) InitiateBoundedProbe( + ctx context.Context, + c Conduit, + x, y Ordered, +) (RangeInfo, error) { + info, err := rsr.is.GetRangeInfo(ctx, nil, x, y, -1) if err != nil { return RangeInfo{}, err } @@ -630,7 +634,7 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo // breaks if we capture the iterator from handleMessage and // pass it to the next handleMessage call (it shouldn't) var msgDone bool - msgDone, err = rsr.handleMessage(c, nil, msg) + msgDone, err = rsr.handleMessage(ctx, c, nil, msg) if !msgDone { done = false } @@ -660,8 +664,9 @@ func fingerprintEqual(a, b any) bool { // CollectStoreItems returns the list of items in the given store func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { + ctx := context.Background() var r []K - it, err := is.Min() + it, err := is.Min(ctx) if err != nil { return nil, err } @@ -672,11 +677,11 @@ func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { if err != nil { return nil, err } - info, err := is.GetRangeInfo(nil, k, k, -1) + info, err := 
is.GetRangeInfo(ctx, nil, k, k, -1) if err != nil { return nil, err } - it, err = is.Min() + it, err = is.Min(ctx) if err != nil { return nil, err } diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index db8cabf6f1..24bfbb0059 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -265,9 +265,14 @@ func (ds *dumbStore) iterFor(s sampleID) Iterator { return ds.iter(n) } -func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { +func (ds *dumbStore) GetRangeInfo( + ctx context.Context, + preceding Iterator, + x, y Ordered, + count int, +) (RangeInfo, error) { if x == nil && y == nil { - it, err := ds.Min() + it, err := ds.Min(ctx) if err != nil { return RangeInfo{}, err } @@ -315,11 +320,16 @@ func (ds *dumbStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) ( return r, nil } -func (ds *dumbStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { +func (ds *dumbStore) SplitRange( + ctx context.Context, + preceding Iterator, + x, y Ordered, + count int, +) (SplitInfo, error) { if count <= 0 { panic("BUG: bad split count") } - part0, err := ds.GetRangeInfo(preceding, x, y, count) + part0, err := ds.GetRangeInfo(ctx, preceding, x, y, count) if err != nil { return SplitInfo{}, err } @@ -330,7 +340,7 @@ func (ds *dumbStore) SplitRange(preceding Iterator, x, y Ordered, count int) (Sp if err != nil { return SplitInfo{}, err } - part1, err := ds.GetRangeInfo(part0.End.Clone(), middle, y, -1) + part1, err := ds.GetRangeInfo(ctx, part0.End.Clone(), middle, y, -1) if err != nil { return SplitInfo{}, err } @@ -340,7 +350,7 @@ func (ds *dumbStore) SplitRange(preceding Iterator, x, y Ordered, count int) (Sp }, nil } -func (ds *dumbStore) Min() (Iterator, error) { +func (ds *dumbStore) Min(ctx context.Context) (Iterator, error) { if len(ds.keys) == 0 { return nil, nil } @@ -354,7 +364,7 @@ func (ds *dumbStore) Copy() ItemStore { return 
&dumbStore{keys: slices.Clone(ds.keys)} } -func (ds *dumbStore) Has(k Ordered) (bool, error) { +func (ds *dumbStore) Has(ctx context.Context, k Ordered) (bool, error) { for _, cur := range ds.keys { if k.Compare(cur) == 0 { return true, nil @@ -491,21 +501,26 @@ func (vs *verifiedStore) verifySameRangeInfo(ri1, ri2 RangeInfo) RangeInfo { return ri } -func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { +func (vs *verifiedStore) GetRangeInfo( + ctx context.Context, + preceding Iterator, + x, y Ordered, + count int, +) (RangeInfo, error) { var ( ri1, ri2 RangeInfo err error ) if preceding != nil { p := preceding.(verifiedStoreIterator) - ri1, err = vs.knownGood.GetRangeInfo(p.knownGood, x, y, count) + ri1, err = vs.knownGood.GetRangeInfo(ctx, p.knownGood, x, y, count) require.NoError(vs.t, err) - ri2, err = vs.store.GetRangeInfo(p.it, x, y, count) + ri2, err = vs.store.GetRangeInfo(ctx, p.it, x, y, count) require.NoError(vs.t, err) } else { - ri1, err = vs.knownGood.GetRangeInfo(nil, x, y, count) + ri1, err = vs.knownGood.GetRangeInfo(ctx, nil, x, y, count) require.NoError(vs.t, err) - ri2, err = vs.store.GetRangeInfo(nil, x, y, count) + ri2, err = vs.store.GetRangeInfo(ctx, nil, x, y, count) require.NoError(vs.t, err) } // QQQQQ: TODO: if count >= 0 and start+end != nil, do more calls to GetRangeInfo using resulting @@ -513,21 +528,26 @@ func (vs *verifiedStore) GetRangeInfo(preceding Iterator, x, y Ordered, count in return vs.verifySameRangeInfo(ri1, ri2), nil } -func (vs *verifiedStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { +func (vs *verifiedStore) SplitRange( + ctx context.Context, + preceding Iterator, + x, y Ordered, + count int, +) (SplitInfo, error) { var ( si1, si2 SplitInfo err error ) if preceding != nil { p := preceding.(verifiedStoreIterator) - si1, err = vs.knownGood.SplitRange(p.knownGood, x, y, count) + si1, err = vs.knownGood.SplitRange(ctx, p.knownGood, x, y, 
count) require.NoError(vs.t, err) - si2, err = vs.store.SplitRange(p.it, x, y, count) + si2, err = vs.store.SplitRange(ctx, p.it, x, y, count) require.NoError(vs.t, err) } else { - si1, err = vs.knownGood.SplitRange(nil, x, y, count) + si1, err = vs.knownGood.SplitRange(ctx, nil, x, y, count) require.NoError(vs.t, err) - si2, err = vs.store.SplitRange(nil, x, y, count) + si2, err = vs.store.SplitRange(ctx, nil, x, y, count) require.NoError(vs.t, err) } require.Equal(vs.t, si1.Middle, si2.Middle, "split middle") @@ -540,10 +560,10 @@ func (vs *verifiedStore) SplitRange(preceding Iterator, x, y Ordered, count int) }, nil } -func (vs *verifiedStore) Min() (Iterator, error) { - m1, err := vs.knownGood.Min() +func (vs *verifiedStore) Min(ctx context.Context) (Iterator, error) { + m1, err := vs.knownGood.Min(ctx) require.NoError(vs.t, err) - m2, err := vs.store.Min() + m2, err := vs.store.Min(ctx) require.NoError(vs.t, err) if m1 == nil { require.Nil(vs.t, m2, "Min") @@ -572,10 +592,10 @@ func (vs *verifiedStore) Copy() ItemStore { } } -func (vs *verifiedStore) Has(k Ordered) (bool, error) { - h1, err := vs.knownGood.Has(k) +func (vs *verifiedStore) Has(ctx context.Context, k Ordered) (bool, error) { + h1, err := vs.knownGood.Has(ctx, k) require.NoError(vs.t, err) - h2, err := vs.store.Has(k) + h2, err := vs.store.Has(ctx, k) require.NoError(vs.t, err) require.Equal(vs.t, h1, h2) return h2, nil @@ -655,13 +675,13 @@ func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...an func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { fc := &fakeConduit{t: t} - require.NoError(t, syncA.Initiate(fc)) + require.NoError(t, syncA.Initiate(context.Background(), fc)) return doRunSync(fc, syncA, syncB, maxRounds) } func runBoundedSync(t *testing.T, syncA, syncB *RangeSetReconciler, x, y Ordered, maxRounds int) (nRounds, nMsg, nItems int) { fc := &fakeConduit{t: t} - require.NoError(t, syncA.InitiateBounded(fc, x, 
y)) + require.NoError(t, syncA.InitiateBounded(context.Background(), fc, x, y)) return doRunSync(fc, syncA, syncB, maxRounds) } @@ -706,14 +726,14 @@ func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) func runProbe(t *testing.T, from, to *RangeSetReconciler) ProbeResult { fc := &fakeConduit{t: t} - info, err := from.InitiateProbe(fc) + info, err := from.InitiateProbe(context.Background(), fc) require.NoError(t, err) return doRunProbe(fc, from, to, info) } func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y Ordered) ProbeResult { fc := &fakeConduit{t: t} - info, err := from.InitiateBoundedProbe(fc, x, y) + info, err := from.InitiateBoundedProbe(context.Background(), fc, x, y) require.NoError(t, err) return doRunProbe(fc, from, to, info) } diff --git a/sync2/hashsync/setsyncbase.go b/sync2/hashsync/setsyncbase.go index 15b8d7225f..4c0dcbd57c 100644 --- a/sync2/hashsync/setsyncbase.go +++ b/sync2/hashsync/setsyncbase.go @@ -35,11 +35,11 @@ func NewSetSyncBase(ps PairwiseSyncer, is ItemStore, handler SyncKeyHandler) *Se } // Count implements syncBase. 
-func (ssb *SetSyncBase) Count() (int, error) { +func (ssb *SetSyncBase) Count(ctx context.Context) (int, error) { // TODO: don't lock on db-bound operations ssb.Lock() defer ssb.Unlock() - it, err := ssb.is.Min() + it, err := ssb.is.Min(ctx) if it == nil || err != nil { return 0, err } @@ -47,7 +47,7 @@ func (ssb *SetSyncBase) Count() (int, error) { if err != nil { return 0, err } - info, err := ssb.is.GetRangeInfo(nil, x, x, -1) + info, err := ssb.is.GetRangeInfo(ctx, nil, x, x, -1) if err != nil { return 0, err } @@ -79,7 +79,7 @@ func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) er ssb.Lock() defer ssb.Unlock() key := k.(fmt.Stringer).String() - has, err := ssb.is.Has(k) + has, err := ssb.is.Has(ctx, k) if err != nil { return err } diff --git a/sync2/hashsync/setsyncbase_test.go b/sync2/hashsync/setsyncbase_test.go index fe92763f8b..0d9e963d62 100644 --- a/sync2/hashsync/setsyncbase_test.go +++ b/sync2/hashsync/setsyncbase_test.go @@ -140,7 +140,7 @@ func TestSetSyncBase(t *testing.T) { st.ps.EXPECT().SyncStore(ctx, p2p.Peer("p1"), ss, &x, &y) require.NoError(t, ss.Sync(ctx, &x, &y)) - st.is.EXPECT().Has(addedKey) + st.is.EXPECT().Has(gomock.Any(), addedKey) st.is.EXPECT().Add(ctx, addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.Sync(ctx, nil, nil)) @@ -164,7 +164,7 @@ func TestSetSyncBase(t *testing.T) { // added just once st.is.EXPECT().Add(ctx, addedKey) for i := 0; i < 3; i++ { - st.is.EXPECT().Has(addedKey) + st.is.EXPECT().Has(gomock.Any(), addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.Sync(ctx, nil, nil)) } @@ -186,8 +186,8 @@ func TestSetSyncBase(t *testing.T) { ss := st.ssb.Derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.Peer()) - st.is.EXPECT().Has(k1) - st.is.EXPECT().Has(k2) + st.is.EXPECT().Has(gomock.Any(), k1) + st.is.EXPECT().Has(gomock.Any(), k2) st.is.EXPECT().Add(ctx, k1) st.is.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, 
p2p.Peer("p1"), ss, k1, k2) @@ -211,8 +211,8 @@ func TestSetSyncBase(t *testing.T) { ss := st.ssb.Derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.Peer()) - st.is.EXPECT().Has(k1) - st.is.EXPECT().Has(k2) + st.is.EXPECT().Has(gomock.Any(), k1) + st.is.EXPECT().Has(gomock.Any(), k2) // k1 is not propagated to syncBase due to the handler failure st.is.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) @@ -240,10 +240,10 @@ func TestSetSyncBase(t *testing.T) { ss.(ItemStore).Add(context.Background(), hs[2]) ss.(ItemStore).Add(context.Background(), hs[3]) // syncer's cloned ItemStore has new key immediately - has, err := ss.(ItemStore).Has(hs[2]) + has, err := ss.(ItemStore).Has(context.Background(), hs[2]) require.NoError(t, err) require.True(t, has) - has, err = ss.(ItemStore).Has(hs[3]) + has, err = ss.(ItemStore).Has(context.Background(), hs[3]) require.True(t, has) handlerErr := errors.New("fail") st.getWaitCh(hs[2]) <- handlerErr @@ -252,9 +252,9 @@ func TestSetSyncBase(t *testing.T) { require.ErrorIs(t, err, handlerErr) require.ElementsMatch(t, hs[2:], handledKeys) // only successfully handled key propagate the syncBase - has, err = is.Has(hs[2]) + has, err = is.Has(context.Background(), hs[2]) require.False(t, has) - has, err = is.Has(hs[3]) + has, err = is.Has(context.Background(), hs[3]) require.True(t, has) }) } diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index 9084c38f23..b7904691dc 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -70,9 +70,14 @@ func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { } // GetRangeInfo implements ItemStore. 
-func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { +func (sts *SyncTreeStore) GetRangeInfo( + ctx context.Context, + preceding Iterator, + x, y Ordered, + count int, +) (RangeInfo, error) { if x == nil && y == nil { - it, err := sts.Min() + it, err := sts.Min(ctx) if err != nil { return RangeInfo{}, err } @@ -115,11 +120,16 @@ func (sts *SyncTreeStore) GetRangeInfo(preceding Iterator, x, y Ordered, count i } // SplitRange implements ItemStore. -func (sts *SyncTreeStore) SplitRange(preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { +func (sts *SyncTreeStore) SplitRange( + ctx context.Context, + preceding Iterator, + x, y Ordered, + count int, +) (SplitInfo, error) { if count <= 0 { panic("BUG: bad split count") } - part0, err := sts.GetRangeInfo(preceding, x, y, count) + part0, err := sts.GetRangeInfo(ctx, preceding, x, y, count) if err != nil { return SplitInfo{}, err } @@ -130,7 +140,7 @@ func (sts *SyncTreeStore) SplitRange(preceding Iterator, x, y Ordered, count int if err != nil { return SplitInfo{}, err } - part1, err := sts.GetRangeInfo(part0.End.Clone(), middle, y, -1) + part1, err := sts.GetRangeInfo(ctx, part0.End.Clone(), middle, y, -1) if err != nil { return SplitInfo{}, err } @@ -141,7 +151,7 @@ func (sts *SyncTreeStore) SplitRange(preceding Iterator, x, y Ordered, count int } // Min implements ItemStore. -func (sts *SyncTreeStore) Min() (Iterator, error) { +func (sts *SyncTreeStore) Min(ctx context.Context) (Iterator, error) { return sts.iter(sts.st.Min()), nil } @@ -154,7 +164,7 @@ func (sts *SyncTreeStore) Copy() ItemStore { } // Has implements ItemStore. 
-func (sts *SyncTreeStore) Has(k Ordered) (bool, error) { +func (sts *SyncTreeStore) Has(ctx context.Context, k Ordered) (bool, error) { _, found := sts.st.Lookup(k) return found, nil } From 0526bcf107230d5cf2ccdb5eb6c196f6d9ab8353 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 24 Aug 2024 21:08:46 +0400 Subject: [PATCH 61/76] p2p: server fixup --- p2p/server/interface.go | 6 +++++- p2p/server/mocks/mocks.go | 2 +- p2p/server/server.go | 15 +++++++++++---- p2p/server/server_test.go | 8 ++++---- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/p2p/server/interface.go b/p2p/server/interface.go index d1c536c81f..06b81acd29 100644 --- a/p2p/server/interface.go +++ b/p2p/server/interface.go @@ -13,7 +13,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/peerinfo" ) -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./interface.go -exclude_interfaces Host +//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./interface.go -exclude_interfaces Host,PeerInfoHost // Host is a subset of libp2p Host interface that needs to be implemented to be usable with server. type Host interface { @@ -21,6 +21,10 @@ type Host interface { NewStream(context.Context, peer.ID, ...protocol.ID) (network.Stream, error) Network() network.Network ConnManager() connmgr.ConnManager +} + +type PeerInfoHost interface { + Host PeerInfo() peerinfo.PeerInfo } diff --git a/p2p/server/mocks/mocks.go b/p2p/server/mocks/mocks.go index 821eb1e751..cd5fb485b7 100644 --- a/p2p/server/mocks/mocks.go +++ b/p2p/server/mocks/mocks.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./interface.go -exclude_interfaces Host +// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./interface.go -exclude_interfaces Host,PeerInfoHost // // Package mocks is a generated GoMock package. 
diff --git a/p2p/server/server.go b/p2p/server/server.go index e79b52e05d..46a2539fce 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -255,6 +255,13 @@ type request struct { received time.Time } +func (s *Server) peerInfo() peerinfo.PeerInfo { + if h, ok := s.h.(PeerInfoHost); ok { + return h.PeerInfo() + } + return nil +} + func (s *Server) Run(ctx context.Context) error { var eg errgroup.Group for { @@ -291,8 +298,8 @@ func (s *Server) Run(ctx context.Context) error { } ok := s.queueHandler(ctx, req.stream) duration := time.Since(req.received) - if s.h.PeerInfo() != nil { - info := s.h.PeerInfo().EnsurePeerInfo(conn.RemotePeer()) + if s.peerInfo() != nil { + info := s.peerInfo().EnsurePeerInfo(conn.RemotePeer()) info.ServerStats.RequestDone(duration, ok) } if s.metrics != nil { @@ -467,8 +474,8 @@ func (s *Server) streamRequest( if err != nil { return nil, nil, err } - if s.h.PeerInfo() != nil { - info = s.h.PeerInfo().EnsurePeerInfo(stream.Conn().RemotePeer()) + if s.peerInfo() != nil { + info = s.peerInfo().EnsurePeerInfo(stream.Conn().RemotePeer()) } dadj := newDeadlineAdjuster(stream, s.timeout, s.hardTimeout) defer func() { diff --git a/p2p/server/server_test.go b/p2p/server/server_test.go index 13d97dbb4c..c6eab164b9 100644 --- a/p2p/server/server_test.go +++ b/p2p/server/server_test.go @@ -121,11 +121,11 @@ func TestServer(t *testing.T) { require.NotEmpty(t, srvConns) require.Equal(t, n+1, srv1.NumAcceptedRequests()) - clientInfo := client.h.PeerInfo().EnsurePeerInfo(srvID) + clientInfo := client.peerInfo().EnsurePeerInfo(srvID) require.Equal(t, 1, clientInfo.ClientStats.SuccessCount()) require.Zero(t, clientInfo.ClientStats.FailureCount()) - serverInfo := srv1.h.PeerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID()) + serverInfo := srv1.peerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID()) require.Eventually(t, func() bool { return serverInfo.ServerStats.SuccessCount() == 1 }, 10*time.Second, 10*time.Millisecond) @@ -152,11 +152,11 @@ func TestServer(t 
*testing.T) { require.ErrorContains(t, err, testErr.Error()) require.Equal(t, n+1, srv1.NumAcceptedRequests()) - clientInfo := client.h.PeerInfo().EnsurePeerInfo(srvID) + clientInfo := client.peerInfo().EnsurePeerInfo(srvID) require.Zero(t, clientInfo.ClientStats.SuccessCount()) require.Equal(t, 1, clientInfo.ClientStats.FailureCount()) - serverInfo := srv2.h.PeerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID()) + serverInfo := srv2.peerInfo().EnsurePeerInfo(mesh.Hosts()[0].ID()) require.Eventually(t, func() bool { return serverInfo.ServerStats.FailureCount() == 1 }, 10*time.Second, 10*time.Millisecond) From 31e28d03b91a99f77f5dd24fd0b6a3bf74273209 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 24 Aug 2024 21:08:59 +0400 Subject: [PATCH 62/76] sync2: build fixup --- common/types/hashes.go | 1 + sync2/dbsync/dbiter_test.go | 2 +- sync2/dbsync/fptree_test.go | 2 +- sync2/p2p_test.go | 4 ++-- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/types/hashes.go b/common/types/hashes.go index dd0243ad21..4281ee9ca5 100644 --- a/common/types/hashes.go +++ b/common/types/hashes.go @@ -10,6 +10,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/util" "github.com/spacemeshos/go-spacemesh/hash" + "github.com/spacemeshos/go-spacemesh/log" ) const ( diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 1f5c06bd33..04aa6ea1fb 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -51,7 +51,7 @@ func createDB(t *testing.T, keyLen int) sql.Database { // sql.WithConnections(16), // ) // require.NoError(t, err) - db := sql.InMemory(sql.WithIgnoreSchemaDrift(), sql.WithConnections(16)) + db := sql.InMemory(sql.WithNoCheckSchemaDrift(), sql.WithConnections(16)) t.Cleanup(func() { require.NoError(t, db.Close()) }) diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 36c02e6265..23cf5e8e01 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -1400,7 +1400,7 @@ func 
testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { // counts := make(map[uint64]uint64) // prefLens := make(map[int]int) // QQQQQ: TBD: reenable schema drift check - db, err := statesql.Open("file:"+dbFile, sql.WithIgnoreSchemaDrift()) + db, err := statesql.Open("file:"+dbFile, sql.WithNoCheckSchemaDrift()) require.NoError(t, err) defer db.Close() // _, err = db.Exec("PRAGMA cache_size = -2000000", nil, nil) diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index c660440657..d8fb13090f 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -70,14 +70,14 @@ func TestP2P(t *testing.T) { for _, hsync := range hs { // use a snapshot to avoid races is := hsync.ItemStore().Copy() - it, err := is.Min() + it, err := is.Min(context.Background()) require.NoError(t, err) if it == nil { return false } k, err := it.Key() require.NoError(t, err) - info, err := is.GetRangeInfo(nil, k, k, -1) + info, err := is.GetRangeInfo(context.Background(), nil, k, k, -1) require.NoError(t, err) if info.Count < numHashes { return false From 1b94a794669adc2ef6f2388d1bb979e219c45784 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 24 Aug 2024 21:10:17 +0400 Subject: [PATCH 63/76] sync2: optimize wire data size --- go.mod | 1 + go.sum | 4 ++ sync2/dbsync/p2p_test.go | 8 +++- sync2/hashsync/handler.go | 33 ++++++++------ sync2/hashsync/handler_test.go | 16 ++++--- sync2/hashsync/wire_helpers.go | 70 ++++++++++++++++++++++++++++++ sync2/hashsync/wire_types.go | 62 +++++++++----------------- sync2/hashsync/wire_types_scale.go | 44 +++++++++---------- 8 files changed, 151 insertions(+), 87 deletions(-) create mode 100644 sync2/hashsync/wire_helpers.go diff --git a/go.mod b/go.mod index b6c969df41..0718e88027 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/prometheus/client_model v0.6.1 github.com/prometheus/common v0.55.0 github.com/quic-go/quic-go v0.46.0 + github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd github.com/rs/cors v1.11.0 
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 diff --git a/go.sum b/go.sum index ab21bb0071..1aeceebb86 100644 --- a/go.sum +++ b/go.sum @@ -152,6 +152,8 @@ github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= +github.com/go-test/deep v1.0.8/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -551,6 +553,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd h1:wW6BtayFoKaaDeIvXRE3SZVPOscSKlYD+X3bB749+zk= +github.com/rqlite/sql v0.0.0-20240312185922-ffac88a740bd/go.mod h1:ib9zVtNgRKiGuoMyUqqL5aNpk+r+++YlyiVIkclVqPg= github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index 26cc9315dc..7c9f36b85b 100644 --- a/sync2/dbsync/p2p_test.go +++ b/sync2/dbsync/p2p_test.go 
@@ -23,6 +23,8 @@ import ( ) func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { + nr := hashsync.RmmeNumRead() + nw := hashsync.RmmeNumWritten() const maxDepth = 24 log := zaptest.NewLogger(t) t.Logf("QQQQQ: 0") @@ -137,6 +139,7 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { return pss.SyncStore(WithSQLExec(ctx, tx), srvPeerID, storeB, x, x) })) t.Logf("synced in %v", time.Since(tStart)) + t.Logf("bytes read: %d, bytes written: %d", hashsync.RmmeNumRead()-nr, hashsync.RmmeNumWritten()-nw) // // QQQQQ: rmme // sb = strings.Builder{} @@ -266,8 +269,11 @@ func TestP2P(t *testing.T) { t.Run("random test", func(t *testing.T) { // TODO: increase these values and profile // const nShared = 8000000 - // const nUniqueA = 40000 + // const nUniqueA = 100 // const nUniqueB = 80000 + // const nShared = 8000000 + // const nUniqueA = 10 + // const nUniqueB = 8000 const nShared = 80000 const nUniqueA = 400 const nUniqueB = 800 diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index 143b2374ca..1f00bf44f2 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -24,6 +24,14 @@ var ( numWritten atomic.Int64 ) +func RmmeNumRead() int64 { + return numRead.Load() +} + +func RmmeNumWritten() int64 { + return numWritten.Load() +} + type rmmeCountingStream struct { io.ReadWriter } @@ -131,8 +139,8 @@ func (c *wireConduit) send(m sendable) error { func (c *wireConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { return c.send(&FingerprintMessage{ - RangeX: x.(types.Hash32), - RangeY: y.(types.Hash32), + RangeX: OrderedToCompactHash32(x), + RangeY: OrderedToCompactHash32(y), RangeFingerprint: fingerprint.(types.Hash12), NumItems: uint32(count), }) @@ -143,13 +151,16 @@ func (c *wireConduit) SendEmptySet() error { } func (c *wireConduit) SendEmptyRange(x, y Ordered) error { - return c.send(&EmptyRangeMessage{RangeX: x.(types.Hash32), RangeY: y.(types.Hash32)}) + return 
c.send(&EmptyRangeMessage{ + RangeX: OrderedToCompactHash32(x), + RangeY: OrderedToCompactHash32(y), + }) } func (c *wireConduit) SendRangeContents(x, y Ordered, count int) error { return c.send(&RangeContentsMessage{ - RangeX: x.(types.Hash32), - RangeY: y.(types.Hash32), + RangeX: OrderedToCompactHash32(x), + RangeY: OrderedToCompactHash32(y), NumItems: uint32(count), }) } @@ -195,10 +206,8 @@ func (c *wireConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) e } else if x == nil || y == nil { panic("BUG: SendProbe: bad range: just one of the bounds is nil") } - xh := x.(types.Hash32) - yh := y.(types.Hash32) - m.RangeX = &xh - m.RangeY = &yh + m.RangeX = OrderedToCompactHash32(x) + m.RangeY = OrderedToCompactHash32(y) return c.send(m) } @@ -226,10 +235,8 @@ func (c *wireConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sa } else if x == nil || y == nil { panic("BUG: SendProbe: bad range: just one of the bounds is nil") } - xh := x.(types.Hash32) - yh := y.(types.Hash32) - m.RangeX = &xh - m.RangeY = &yh + m.RangeX = OrderedToCompactHash32(x) + m.RangeY = OrderedToCompactHash32(y) return c.send(m) } diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index 8da457e945..c2252ea1fa 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -244,8 +244,8 @@ func TestWireConduit(t *testing.T) { name: "server got 1st request", expectMsgs: []SyncMessage{ &FingerprintMessage{ - RangeX: hs[0], - RangeY: hs[1], + RangeX: Hash32ToCompact(hs[0]), + RangeY: Hash32ToCompact(hs[1]), RangeFingerprint: fp, NumItems: 4, }, @@ -314,13 +314,13 @@ func TestWireConduit(t *testing.T) { name: "client got 1st response", expectMsgs: []SyncMessage{ &RangeContentsMessage{ - RangeX: hs[0], - RangeY: hs[3], + RangeX: Hash32ToCompact(hs[0]), + RangeY: Hash32ToCompact(hs[3]), NumItems: 2, }, &RangeContentsMessage{ - RangeX: hs[3], - RangeY: hs[6], + RangeX: Hash32ToCompact(hs[3]), + RangeY: Hash32ToCompact(hs[6]), 
NumItems: 2, }, &ItemBatchMessage{ @@ -442,6 +442,8 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) Requester { withClientServer( storeA, getRequester, opts, func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { + nr := RmmeNumRead() + nw := RmmeNumWritten() pss := NewPairwiseStoreSyncer(client, opts) err := pss.SyncStore(ctx, srvPeerID, storeB, nil, nil) require.NoError(t, err) @@ -450,7 +452,7 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) Requester { t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", numSpecific, fr.bytesSent, fr.bytesReceived) } - t.Logf("bytes read: %d, bytes written: %d", numRead.Load(), numWritten.Load()) + t.Logf("bytes read: %d, bytes written: %d", RmmeNumRead()-nr, RmmeNumWritten()-nw) }) return true }) diff --git a/sync2/hashsync/wire_helpers.go b/sync2/hashsync/wire_helpers.go new file mode 100644 index 0000000000..4229fc6046 --- /dev/null +++ b/sync2/hashsync/wire_helpers.go @@ -0,0 +1,70 @@ +package hashsync + +import ( + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/common/types" +) + +type CompactHash32 struct { + H *types.Hash32 +} + +// DecodeScale implements scale.Decodable. +func (c *CompactHash32) DecodeScale(dec *scale.Decoder) (int, error) { + var h types.Hash32 + b, total, err := scale.DecodeByte(dec) + switch { + case err != nil: + return total, err + case b == 255: + c.H = nil + return total, nil + case b != 0: + n, err := scale.DecodeByteArray(dec, h[:b]) + total += n + if err != nil { + return total, err + } + } + c.H = &h + return total, nil +} + +// EncodeScale implements scale.Encodable. 
+func (c *CompactHash32) EncodeScale(enc *scale.Encoder) (int, error) { + if c.H == nil { + return scale.EncodeByte(enc, 255) + } + + b := byte(31) + for b = 32; b > 0; b-- { + if c.H[b-1] != 0 { + break + } + } + + total, err := scale.EncodeByte(enc, b) + if b == 0 || err != nil { + return total, err + } + + n, err := scale.EncodeByteArray(enc, c.H[:b]) + total += n + return total, err +} + +func (c *CompactHash32) ToOrdered() Ordered { + if c.H == nil { + return nil + } + return *c.H +} + +func Hash32ToCompact(h types.Hash32) CompactHash32 { + return CompactHash32{H: &h} +} + +func OrderedToCompactHash32(h Ordered) CompactHash32 { + hash := h.(types.Hash32) + return CompactHash32{H: &hash} +} diff --git a/sync2/hashsync/wire_types.go b/sync2/hashsync/wire_types.go index bc94a6fa7d..54605228be 100644 --- a/sync2/hashsync/wire_types.go +++ b/sync2/hashsync/wire_types.go @@ -44,14 +44,14 @@ func (*EmptySetMessage) Type() MessageType { return MessageTypeEmptySet } // EmptyRangeMessage notifies the peer that it needs to send all of its items in // the specified range type EmptyRangeMessage struct { - RangeX, RangeY types.Hash32 + RangeX, RangeY CompactHash32 } var _ SyncMessage = &EmptyRangeMessage{} func (m *EmptyRangeMessage) Type() MessageType { return MessageTypeEmptyRange } -func (m *EmptyRangeMessage) X() Ordered { return m.RangeX } -func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY } +func (m *EmptyRangeMessage) X() Ordered { return m.RangeX.ToOrdered() } +func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *EmptyRangeMessage) Fingerprint() any { return nil } func (m *EmptyRangeMessage) Count() int { return 0 } func (m *EmptyRangeMessage) Keys() []Ordered { return nil } @@ -59,7 +59,7 @@ func (m *EmptyRangeMessage) Keys() []Ordered { return nil } // FingerprintMessage contains range fingerprint for comparison against the // peer's fingerprint of the range with the same bounds [RangeX, RangeY) type FingerprintMessage struct { 
- RangeX, RangeY types.Hash32 + RangeX, RangeY CompactHash32 RangeFingerprint types.Hash12 NumItems uint32 } @@ -67,8 +67,8 @@ type FingerprintMessage struct { var _ SyncMessage = &FingerprintMessage{} func (m *FingerprintMessage) Type() MessageType { return MessageTypeFingerprint } -func (m *FingerprintMessage) X() Ordered { return m.RangeX } -func (m *FingerprintMessage) Y() Ordered { return m.RangeY } +func (m *FingerprintMessage) X() Ordered { return m.RangeX.ToOrdered() } +func (m *FingerprintMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *FingerprintMessage) Fingerprint() any { return m.RangeFingerprint } func (m *FingerprintMessage) Count() int { return int(m.NumItems) } func (m *FingerprintMessage) Keys() []Ordered { return nil } @@ -77,15 +77,15 @@ func (m *FingerprintMessage) Keys() []Ordered { return nil } // The peer needs to send back any items it has in the same range bounded // by [RangeX, RangeY) type RangeContentsMessage struct { - RangeX, RangeY types.Hash32 + RangeX, RangeY CompactHash32 NumItems uint32 } var _ SyncMessage = &RangeContentsMessage{} func (m *RangeContentsMessage) Type() MessageType { return MessageTypeRangeContents } -func (m *RangeContentsMessage) X() Ordered { return m.RangeX } -func (m *RangeContentsMessage) Y() Ordered { return m.RangeY } +func (m *RangeContentsMessage) X() Ordered { return m.RangeX.ToOrdered() } +func (m *RangeContentsMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *RangeContentsMessage) Fingerprint() any { return nil } func (m *RangeContentsMessage) Count() int { return int(m.NumItems) } func (m *RangeContentsMessage) Keys() []Ordered { return nil } @@ -111,7 +111,7 @@ func (m *ItemBatchMessage) Keys() []Ordered { // ProbeMessage requests bounded range fingerprint and count from the peer, // along with a minhash sample if fingerprints differ type ProbeMessage struct { - RangeX, RangeY *types.Hash32 + RangeX, RangeY CompactHash32 RangeFingerprint types.Hash12 SampleSize uint32 } @@ 
-119,22 +119,11 @@ type ProbeMessage struct { var _ SyncMessage = &ProbeMessage{} func (m *ProbeMessage) Type() MessageType { return MessageTypeProbe } -func (m *ProbeMessage) X() Ordered { - if m.RangeX == nil { - return nil - } - return *m.RangeX -} - -func (m *ProbeMessage) Y() Ordered { - if m.RangeY == nil { - return nil - } - return *m.RangeY -} -func (m *ProbeMessage) Fingerprint() any { return m.RangeFingerprint } -func (m *ProbeMessage) Count() int { return int(m.SampleSize) } -func (m *ProbeMessage) Keys() []Ordered { return nil } +func (m *ProbeMessage) X() Ordered { return m.RangeX.ToOrdered() } +func (m *ProbeMessage) Y() Ordered { return m.RangeY.ToOrdered() } +func (m *ProbeMessage) Fingerprint() any { return m.RangeFingerprint } +func (m *ProbeMessage) Count() int { return int(m.SampleSize) } +func (m *ProbeMessage) Keys() []Ordered { return nil } // MinhashSampleItem represents an item of minhash sample subset type MinhashSampleItem uint32 @@ -171,7 +160,7 @@ func MinhashSampleItemFromHash32(h types.Hash32) MinhashSampleItem { // ProbeResponseMessage is a response to ProbeMessage type ProbeResponseMessage struct { - RangeX, RangeY *types.Hash32 + RangeX, RangeY CompactHash32 RangeFingerprint types.Hash12 NumItems uint32 // NOTE: max must be in sync with maxSampleSize in hashsync/rangesync.go @@ -181,21 +170,10 @@ type ProbeResponseMessage struct { var _ SyncMessage = &ProbeResponseMessage{} func (m *ProbeResponseMessage) Type() MessageType { return MessageTypeProbeResponse } -func (m *ProbeResponseMessage) X() Ordered { - if m.RangeX == nil { - return nil - } - return *m.RangeX -} - -func (m *ProbeResponseMessage) Y() Ordered { - if m.RangeY == nil { - return nil - } - return *m.RangeY -} -func (m *ProbeResponseMessage) Fingerprint() any { return m.RangeFingerprint } -func (m *ProbeResponseMessage) Count() int { return int(m.NumItems) } +func (m *ProbeResponseMessage) X() Ordered { return m.RangeX.ToOrdered() } +func (m *ProbeResponseMessage) Y() 
Ordered { return m.RangeY.ToOrdered() } +func (m *ProbeResponseMessage) Fingerprint() any { return m.RangeFingerprint } +func (m *ProbeResponseMessage) Count() int { return int(m.NumItems) } func (m *ProbeResponseMessage) Keys() []Ordered { r := make([]Ordered, len(m.Sample)) diff --git a/sync2/hashsync/wire_types_scale.go b/sync2/hashsync/wire_types_scale.go index 4a0343e2fa..90d5032a70 100644 --- a/sync2/hashsync/wire_types_scale.go +++ b/sync2/hashsync/wire_types_scale.go @@ -84,14 +84,14 @@ func (t *EmptySetMessage) DecodeScale(dec *scale.Decoder) (total int, err error) func (t *EmptyRangeMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeByteArray(enc, t.RangeX[:]) + n, err := t.RangeX.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.RangeY[:]) + n, err := t.RangeY.EncodeScale(enc) if err != nil { return total, err } @@ -102,14 +102,14 @@ func (t *EmptyRangeMessage) EncodeScale(enc *scale.Encoder) (total int, err erro func (t *EmptyRangeMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.RangeX[:]) + n, err := t.RangeX.DecodeScale(dec) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.RangeY[:]) + n, err := t.RangeY.DecodeScale(dec) if err != nil { return total, err } @@ -120,14 +120,14 @@ func (t *EmptyRangeMessage) DecodeScale(dec *scale.Decoder) (total int, err erro func (t *FingerprintMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeByteArray(enc, t.RangeX[:]) + n, err := t.RangeX.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.RangeY[:]) + n, err := t.RangeY.EncodeScale(enc) if err != nil { return total, err } @@ -152,14 +152,14 @@ func (t *FingerprintMessage) EncodeScale(enc *scale.Encoder) (total int, err err func (t *FingerprintMessage) DecodeScale(dec 
*scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.RangeX[:]) + n, err := t.RangeX.DecodeScale(dec) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.RangeY[:]) + n, err := t.RangeY.DecodeScale(dec) if err != nil { return total, err } @@ -185,14 +185,14 @@ func (t *FingerprintMessage) DecodeScale(dec *scale.Decoder) (total int, err err func (t *RangeContentsMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeByteArray(enc, t.RangeX[:]) + n, err := t.RangeX.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeByteArray(enc, t.RangeY[:]) + n, err := t.RangeY.EncodeScale(enc) if err != nil { return total, err } @@ -210,14 +210,14 @@ func (t *RangeContentsMessage) EncodeScale(enc *scale.Encoder) (total int, err e func (t *RangeContentsMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { { - n, err := scale.DecodeByteArray(dec, t.RangeX[:]) + n, err := t.RangeX.DecodeScale(dec) if err != nil { return total, err } total += n } { - n, err := scale.DecodeByteArray(dec, t.RangeY[:]) + n, err := t.RangeY.DecodeScale(dec) if err != nil { return total, err } @@ -259,14 +259,14 @@ func (t *ItemBatchMessage) DecodeScale(dec *scale.Decoder) (total int, err error func (t *ProbeMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeOption(enc, t.RangeX) + n, err := t.RangeX.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeOption(enc, t.RangeY) + n, err := t.RangeY.EncodeScale(enc) if err != nil { return total, err } @@ -291,20 +291,18 @@ func (t *ProbeMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { func (t *ProbeMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { { - field, n, err := scale.DecodeOption[types.Hash32](dec) + n, err := t.RangeX.DecodeScale(dec) if err != nil { return total, err } total += n - 
t.RangeX = field } { - field, n, err := scale.DecodeOption[types.Hash32](dec) + n, err := t.RangeY.DecodeScale(dec) if err != nil { return total, err } total += n - t.RangeY = field } { n, err := scale.DecodeByteArray(dec, t.RangeFingerprint[:]) @@ -326,14 +324,14 @@ func (t *ProbeMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { func (t *ProbeResponseMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { - n, err := scale.EncodeOption(enc, t.RangeX) + n, err := t.RangeX.EncodeScale(enc) if err != nil { return total, err } total += n } { - n, err := scale.EncodeOption(enc, t.RangeY) + n, err := t.RangeY.EncodeScale(enc) if err != nil { return total, err } @@ -365,20 +363,18 @@ func (t *ProbeResponseMessage) EncodeScale(enc *scale.Encoder) (total int, err e func (t *ProbeResponseMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { { - field, n, err := scale.DecodeOption[types.Hash32](dec) + n, err := t.RangeX.DecodeScale(dec) if err != nil { return total, err } total += n - t.RangeX = field } { - field, n, err := scale.DecodeOption[types.Hash32](dec) + n, err := t.RangeY.DecodeScale(dec) if err != nil { return total, err } total += n - t.RangeY = field } { n, err := scale.DecodeByteArray(dec, t.RangeFingerprint[:]) From b4e9c0db244cf46a0b50799853aa1a5683203acb Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sun, 25 Aug 2024 03:40:38 +0400 Subject: [PATCH 64/76] sync2: add table snapshotting based on rowid --- sql/database.go | 10 ++ sync2/dbsync/dbitemstore.go | 82 ++++----- sync2/dbsync/dbitemstore_test.go | 24 ++- sync2/dbsync/dbiter.go | 13 +- sync2/dbsync/dbiter_test.go | 27 ++- sync2/dbsync/fptree_test.go | 30 +++- sync2/dbsync/p2p_test.go | 8 +- sync2/dbsync/sqlidstore.go | 18 +- sync2/dbsync/sqlidstore_test.go | 8 +- sync2/dbsync/syncedtable.go | 200 +++++++++++++++++++++ sync2/dbsync/syncedtable_test.go | 299 +++++++++++++++++++++++++++++++ 11 files changed, 627 insertions(+), 92 deletions(-) create mode 100644 
sync2/dbsync/syncedtable.go create mode 100644 sync2/dbsync/syncedtable_test.go diff --git a/sql/database.go b/sql/database.go index a647086f4a..51f4a7dd1e 100644 --- a/sql/database.go +++ b/sql/database.go @@ -11,6 +11,7 @@ import ( "strings" "sync" "sync/atomic" + "testing" "time" sqlite "github.com/go-llsqlite/crawshaw" @@ -233,6 +234,15 @@ func InMemory(opts ...Opt) *sqliteDatabase { return db } +// InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. +func InMemoryTest(tb testing.TB, opts ...Opt) *sqliteDatabase { + // When using empty DB schema, we don't want to check for schema drift due to + // "PRAGMA user_version = 0;" in the initial schema retrieved from the DB. + db := InMemory(append(opts, WithNoCheckSchemaDrift())...) + tb.Cleanup(func() { db.Close() }) + return db +} + // Open database with options. // // Database is opened in WAL mode and pragma synchronous=normal. diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index dc86035af8..479100152d 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -3,6 +3,7 @@ package dbsync import ( "context" "errors" + "fmt" "sync" "github.com/spacemeshos/go-spacemesh/common/types" @@ -11,57 +12,52 @@ import ( ) type DBItemStore struct { - loadMtx sync.Mutex - loaded bool - db sql.Database - ft *fpTree - loadQuery string - iterQuery string - keyLen int - maxDepth int + loadMtx sync.Mutex + db sql.Database + ft *fpTree + st *SyncedTable + keyLen int + maxDepth int } var _ hashsync.ItemStore = &DBItemStore{} func NewDBItemStore( db sql.Database, - loadQuery, iterQuery string, + st *SyncedTable, keyLen, maxDepth int, ) *DBItemStore { - var np nodePool - dbStore := newDBBackedStore(db, iterQuery, keyLen) + // var np nodePool + // dbStore := newDBBackedStore(db, iterQuery, keyLen) return &DBItemStore{ - db: db, - ft: newFPTree(&np, dbStore, keyLen, maxDepth), - loadQuery: loadQuery, - iterQuery: iterQuery, - keyLen: keyLen, - 
maxDepth: maxDepth, + db: db, + // ft: newFPTree(&np, dbStore, keyLen, maxDepth), + st: st, + keyLen: keyLen, + maxDepth: maxDepth, } } -func (d *DBItemStore) load(ctx context.Context) error { - db := ContextSQLExec(ctx, d.db) - _, err := db.Exec(d.loadQuery, nil, - func(stmt *sql.Statement) bool { - id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID - stmt.ColumnBytes(0, id[:]) - d.ft.addStoredHash(id) - return true - }) - return err -} - func (d *DBItemStore) EnsureLoaded(ctx context.Context) error { d.loadMtx.Lock() defer d.loadMtx.Unlock() - if !d.loaded { - if err := d.load(ctx); err != nil { - return err - } - d.loaded = true + if d.ft != nil { + return nil } - return nil + db := ContextSQLExec(ctx, d.db) + sts, err := d.st.snapshot(db) + if err != nil { + return fmt.Errorf("error taking snapshot: %w", err) + } + var np nodePool + dbStore := newDBBackedStore(db, sts, d.keyLen) + d.ft = newFPTree(&np, dbStore, d.keyLen, d.maxDepth) + return sts.loadIDs(db, func(stmt *sql.Statement) bool { + id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID + stmt.ColumnBytes(0, id[:]) + d.ft.addStoredHash(id) + return true + }) } // Add implements hashsync.ItemStore. @@ -180,18 +176,18 @@ func (d *DBItemStore) Min(ctx context.Context) (hashsync.Iterator, error) { // Copy implements hashsync.ItemStore. 
func (d *DBItemStore) Copy() hashsync.ItemStore { - if !d.loaded { + d.loadMtx.Lock() + d.loadMtx.Unlock() + if d.ft == nil { // FIXME panic("BUG: can't copy DBItemStore before it's loaded") } return &DBItemStore{ - db: d.db, - ft: d.ft.clone(), - loadQuery: d.loadQuery, - iterQuery: d.iterQuery, - keyLen: d.keyLen, - maxDepth: d.maxDepth, - loaded: true, + db: d.db, + ft: d.ft.clone(), + st: d.st, + keyLen: d.keyLen, + maxDepth: d.maxDepth, } } diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go index ecb6bae5f0..f1d36e1aa4 100644 --- a/sync2/dbsync/dbitemstore_test.go +++ b/sync2/dbsync/dbitemstore_test.go @@ -11,7 +11,11 @@ import ( func TestDBItemStoreEmpty(t *testing.T) { db := populateDB(t, 32, nil) - s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBItemStore(db, st, 32, 24) ctx := context.Background() it, err := s.Min(ctx) require.NoError(t, err) @@ -48,7 +52,11 @@ func TestDBItemStore(t *testing.T) { } ctx := context.Background() db := populateDB(t, 32, ids) - s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBItemStore(db, st, 32, 24) it, err := s.Min(ctx) require.NoError(t, err) require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", @@ -137,7 +145,11 @@ func TestDBItemStoreAdd(t *testing.T) { util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), } db := populateDB(t, 32, ids) - s := NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBItemStore(db, st, 32, 24) ctx := context.Background() it, err := s.Min(ctx) require.NoError(t, err) @@ -169,7 +181,11 @@ func TestDBItemStoreCopy(t *testing.T) { util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), } db := populateDB(t, 32, ids) - s 
:= NewDBItemStore(db, "select id from foo", testQuery, 32, 24) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBItemStore(db, st, 32, 24) ctx := context.Background() it, err := s.Min(ctx) require.NoError(t, err) diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index 2099fe3820..88c174780c 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -75,7 +75,7 @@ func newLRU() *lru { type dbRangeIterator struct { db sql.Executor from KeyBytes - query string + sts *SyncedTableSnapshot chunkSize int maxChunkSize int chunk []KeyBytes @@ -92,7 +92,7 @@ var _ hashsync.Iterator = &dbRangeIterator{} // If query returns no rows even after starting from zero ID, errEmptySet error is returned. func newDBRangeIterator( db sql.Executor, - query string, + sts *SyncedTableSnapshot, from KeyBytes, maxChunkSize int, lru *lru, @@ -106,7 +106,7 @@ func newDBRangeIterator( return &dbRangeIterator{ db: db, from: from.Clone(), - query: query, + sts: sts, chunkSize: 1, maxChunkSize: maxChunkSize, keyLen: len(from), @@ -156,11 +156,8 @@ func (it *dbRangeIterator) load() error { var ierr, err error found, n := it.loadCached(key) if !found { - _, err = it.db.Exec( - it.query, func(stmt *sql.Statement) { - stmt.BindBytes(1, it.from) - stmt.BindInt64(2, int64(it.chunkSize)) - }, + err := it.sts.loadIDRange( + it.db, it.from, it.chunkSize, func(stmt *sql.Statement) bool { if n >= len(it.chunk) { ierr = errors.New("too many rows") diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 04aa6ea1fb..4e6e12a68f 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -8,9 +8,10 @@ import ( "slices" "testing" + "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" - "github.com/stretchr/testify/require" ) func TestIncID(t *testing.T) { @@ -42,19 +43,7 @@ func TestIncID(t *testing.T) { } func createDB(t *testing.T, keyLen int) 
sql.Database { - // QQQQQ: FIXME - // tmpDir := t.TempDir() - // t.Logf("QQQQQ: temp dir: %s", tmpDir) - // db, err := sql.Open( - // fmt.Sprintf("file:%s/test.db", tmpDir), - // sql.WithIgnoreSchemaDrift(), - // sql.WithConnections(16), - // ) - // require.NoError(t, err) - db := sql.InMemory(sql.WithNoCheckSchemaDrift(), sql.WithConnections(16)) - t.Cleanup(func() { - require.NoError(t, db.Close()) - }) + db := sql.InMemoryTest(t) _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) require.NoError(t, err) return db @@ -88,8 +77,6 @@ func populateDB(t *testing.T, keyLen int, content []KeyBytes) sql.Database { return db } -const testQuery = "select id from foo where id >= ? order by id limit ?" - func TestDBRangeIterator(t *testing.T) { db := createDB(t, 4) for _, tc := range []struct { @@ -304,8 +291,14 @@ func TestDBRangeIterator(t *testing.T) { deleteDBItems(t, db) insertDBItems(t, db, tc.items) cache := newLRU() + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + sts, err := st.snapshot(db) + require.NoError(t, err) for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { - it := newDBRangeIterator(db, testQuery, tc.from, maxChunkSize, cache) + it := newDBRangeIterator(db, sts, tc.from, maxChunkSize, cache) if tc.expErr != nil { _, err := it.Key() require.ErrorIs(t, err, tc.expErr) diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 23cf5e8e01..5ddfce6d96 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -202,10 +202,14 @@ type fakeIDDBStore struct { var _ idStore = &fakeIDDBStore{} -const fakeIDQuery = "select id from foo where id >= ? order by id limit ?" 
- -func newFakeATXIDStore(db sql.Database, maxDepth int) *fakeIDDBStore { - return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, fakeIDQuery, 32)} +func newFakeATXIDStore(t *testing.T, db sql.Database, maxDepth int) *fakeIDDBStore { + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + sts, err := st.snapshot(db) + require.NoError(t, err) + return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, sts, 32)} } func (s *fakeIDDBStore) registerHash(h KeyBytes) error { @@ -773,7 +777,7 @@ func TestFPTree(t *testing.T) { testFPTree(t, func(maxDepth int) idStore { _, err := db.Exec("delete from foo", nil, nil) require.NoError(t, err) - return newFakeATXIDStore(db, maxDepth) + return newFakeATXIDStore(t, db, maxDepth) }) }) } @@ -1201,7 +1205,7 @@ func TestFPTreeManyItems(t *testing.T) { repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { _, err := db.Exec("delete from foo", nil, nil) require.NoError(t, err) - return newFakeATXIDStore(db, maxDepth) + return newFakeATXIDStore(t, db, maxDepth) }, false, numItems, maxDepth, repeatOuter, repeatInner) }) t.Run("SQL, random bounds", func(t *testing.T) { @@ -1210,7 +1214,7 @@ func TestFPTreeManyItems(t *testing.T) { repeatTestFPTreeManyItems(t, func(maxDepth int) idStore { _, err := db.Exec("delete from foo", nil, nil) require.NoError(t, err) - return newFakeATXIDStore(db, maxDepth) + return newFakeATXIDStore(t, db, maxDepth) }, true, numItems, maxDepth, repeatOuter, repeatInner) }) // TBD: test limits with both random and non-random bounds @@ -1437,7 +1441,17 @@ func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { var stats1 runtime.MemStats runtime.ReadMemStats(&stats1) // TODO: pass extra bind params to the SQL query - store := newSQLIDStore(db, "select id from atxs where id >= ? 
and epoch = 26 order by id limit ?", 32) + st := &SyncedTable{ + TableName: "atxs", + IDColumn: "id", + Filter: parseSQLExpr(t, "epoch = ?"), + Binder: func(stmt *sql.Statement) { + stmt.BindInt64(1, 26) + }, + } + sts, err := st.snapshot(db) + require.NoError(t, err) + store := newSQLIDStore(db, sts, 32) ft := newFPTree(&np, store, 32, maxDepth) for _, id := range *hs { ft.addHash(id[:]) diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index 7c9f36b85b..bacd93a8d1 100644 --- a/sync2/dbsync/p2p_test.go +++ b/sync2/dbsync/p2p_test.go @@ -36,13 +36,17 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { proto := "itest" t.Logf("QQQQQ: 2") ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) - storeA := NewItemStoreAdapter(NewDBItemStore(dbA, "select id from foo", testQuery, 32, maxDepth)) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + storeA := NewItemStoreAdapter(NewDBItemStore(dbA, st, 32, maxDepth)) t.Logf("QQQQQ: 2.1") require.NoError(t, dbA.WithTx(ctx, func(tx sql.Transaction) error { return storeA.s.EnsureLoaded(WithSQLExec(ctx, tx)) })) t.Logf("QQQQQ: 3") - storeB := NewItemStoreAdapter(NewDBItemStore(dbB, "select id from foo", testQuery, 32, maxDepth)) + storeB := NewItemStoreAdapter(NewDBItemStore(dbB, st, 32, maxDepth)) t.Logf("QQQQQ: 3.1") var x *types.Hash32 require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 1c90aa8743..431f088820 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -16,7 +16,7 @@ func WithSQLExec(ctx context.Context, db sql.Executor) context.Context { return context.WithValue(ctx, dbExecKey{}, db) } -func ContextSQLExec(ctx context.Context, db sql.Database) sql.Executor { +func ContextSQLExec(ctx context.Context, db sql.Executor) sql.Executor { v := ctx.Value(dbExecKey{}) if v == nil { return db @@ -31,25 +31,25 @@ func WithSQLTransaction(ctx 
context.Context, db sql.Database, toCall func(contex } type sqlIDStore struct { - db sql.Database - query string + db sql.Executor + sts *SyncedTableSnapshot keyLen int cache *lru } var _ idStore = &sqlIDStore{} -func newSQLIDStore(db sql.Database, query string, keyLen int) *sqlIDStore { +func newSQLIDStore(db sql.Executor, sts *SyncedTableSnapshot, keyLen int) *sqlIDStore { return &sqlIDStore{ db: db, - query: query, + sts: sts, keyLen: keyLen, cache: newLRU(), } } func (s *sqlIDStore) clone() idStore { - return newSQLIDStore(s.db, s.query, s.keyLen) + return newSQLIDStore(s.db, s.sts, s.keyLen) } func (s *sqlIDStore) registerHash(h KeyBytes) error { @@ -66,7 +66,7 @@ func (s *sqlIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator if len(from) != s.keyLen { panic("BUG: invalid key length") } - return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.query, from, sqlMaxChunkSize, s.cache) + return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.sts, from, sqlMaxChunkSize, s.cache) } type dbBackedStore struct { @@ -76,9 +76,9 @@ type dbBackedStore struct { var _ idStore = &dbBackedStore{} -func newDBBackedStore(db sql.Database, query string, keyLen int) *dbBackedStore { +func newDBBackedStore(db sql.Executor, sts *SyncedTableSnapshot, keyLen int) *dbBackedStore { return &dbBackedStore{ - sqlIDStore: newSQLIDStore(db, query, keyLen), + sqlIDStore: newSQLIDStore(db, sts, keyLen), inMemIDStore: newInMemIDStore(keyLen), } } diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index 3007747e89..87fd42e1df 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ b/sync2/dbsync/sqlidstore_test.go @@ -15,8 +15,14 @@ func TestDBBackedStore(t *testing.T) { {0, 0, 0, 7, 0, 0, 0, 0}, } db := populateDB(t, 8, initialIDs) + st := SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + sts, err := st.snapshot(db) + require.NoError(t, err) verify := func(t *testing.T, ctx context.Context) { - store := newDBBackedStore(db, fakeIDQuery, 8) + store := 
newDBBackedStore(db, sts, 8) it := store.iter(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) var actualIDs []KeyBytes for range 5 { diff --git a/sync2/dbsync/syncedtable.go b/sync2/dbsync/syncedtable.go new file mode 100644 index 0000000000..c7b75427f1 --- /dev/null +++ b/sync2/dbsync/syncedtable.go @@ -0,0 +1,200 @@ +package dbsync + +import ( + "fmt" + + rsql "github.com/rqlite/sql" + "github.com/spacemeshos/go-spacemesh/sql" +) + +type Binder func(s *sql.Statement) + +type SyncedTable struct { + TableName string + IDColumn string + Filter rsql.Expr + Binder Binder +} + +func (st *SyncedTable) genSelectAll() *rsql.SelectStatement { + return &rsql.SelectStatement{ + Columns: []*rsql.ResultColumn{ + {Expr: &rsql.Ident{Name: st.IDColumn}}, + }, + Source: &rsql.QualifiedTableName{Name: &rsql.Ident{Name: st.TableName}}, + WhereExpr: st.Filter, + } +} + +func (st *SyncedTable) genSelectMaxRowID() *rsql.SelectStatement { + return &rsql.SelectStatement{ + Columns: []*rsql.ResultColumn{ + { + Expr: &rsql.Call{ + Name: &rsql.Ident{Name: "max"}, + Args: []rsql.Expr{&rsql.Ident{Name: "rowid"}}, + }, + }, + }, + Source: &rsql.QualifiedTableName{Name: &rsql.Ident{Name: st.TableName}}, + } +} + +func (st *SyncedTable) genSelectIDRange() *rsql.SelectStatement { + s := st.genSelectAll() + where := &rsql.BinaryExpr{ + X: &rsql.Ident{Name: st.IDColumn}, + Op: rsql.GE, + Y: &rsql.BindExpr{Name: "?"}, + } + if s.WhereExpr != nil { + s.WhereExpr = &rsql.BinaryExpr{ + X: s.WhereExpr, + Op: rsql.AND, + Y: where, + } + } else { + s.WhereExpr = where + } + s.OrderingTerms = []*rsql.OrderingTerm{ + {X: &rsql.Ident{Name: st.IDColumn}}, + } + s.LimitExpr = &rsql.BindExpr{Name: "?"} + return s +} + +func (st *SyncedTable) rowIDCutoff() rsql.Expr { + return &rsql.BinaryExpr{ + X: &rsql.Ident{Name: "rowid"}, + Op: rsql.LE, + Y: &rsql.BindExpr{Name: "?"}, + } +} + +func (st *SyncedTable) genSelectAllRowIDCutoff() *rsql.SelectStatement { + s := st.genSelectAll() + if s.WhereExpr != nil { + s.WhereExpr = 
&rsql.BinaryExpr{ + X: s.WhereExpr, + Op: rsql.AND, + Y: st.rowIDCutoff(), + } + } else { + s.WhereExpr = st.rowIDCutoff() + } + return s +} + +func (st *SyncedTable) genSelectIDRangeWithRowIDCutoff() *rsql.SelectStatement { + s := st.genSelectIDRange() + s.WhereExpr = &rsql.BinaryExpr{ + X: s.WhereExpr, + Op: rsql.AND, + Y: st.rowIDCutoff(), + } + return s +} + +func (st *SyncedTable) loadMaxRowID(db sql.Executor) (maxRowID int64, err error) { + nRows, err := db.Exec( + st.genSelectMaxRowID().String(), nil, + func(st *sql.Statement) bool { + maxRowID = st.ColumnInt64(0) + return true + }) + if nRows != 1 { + return 0, fmt.Errorf("expected 1 row, got %d", nRows) + } + return maxRowID, err +} + +func (st *SyncedTable) snapshot(db sql.Executor) (*SyncedTableSnapshot, error) { + maxRowID, err := st.loadMaxRowID(db) + if err != nil { + return nil, err + } + return &SyncedTableSnapshot{st, maxRowID}, nil +} + +type SyncedTableSnapshot struct { + *SyncedTable + maxRowID int64 +} + +func (sts *SyncedTableSnapshot) loadIDs( + db sql.Executor, + dec func(stmt *sql.Statement) bool, +) error { + _, err := db.Exec( + sts.genSelectAllRowIDCutoff().String(), + func(stmt *sql.Statement) { + if sts.Binder != nil { + sts.Binder(stmt) + } + stmt.BindInt64(stmt.BindParamCount(), sts.maxRowID) + }, + dec) + return err +} + +func (sts *SyncedTableSnapshot) loadIDRange( + db sql.Executor, + fromID KeyBytes, + limit int, + dec func(stmt *sql.Statement) bool, +) error { + _, err := db.Exec( + sts.genSelectIDRangeWithRowIDCutoff().String(), + func(stmt *sql.Statement) { + if sts.Binder != nil { + sts.Binder(stmt) + } + nParams := stmt.BindParamCount() + stmt.BindBytes(nParams-2, fromID) + stmt.BindInt64(nParams-1, sts.maxRowID) + stmt.BindInt64(nParams, int64(limit)) + }, + dec) + return err +} + +// func (st *SyncedTable) bind(s *sql.Statement) int { +// ofs := 0 +// if st.Filter != nil { +// var v bindCountVisitor +// if err := rsql.Walk(&v, st.Filter); err != nil { +// panic("BUG: bad 
filter: " + err.Error()) +// } +// ofs = v.numBinds +// switch { +// case ofs == 0 && st.Binder != nil: +// panic("BUG: filter has no binds but a binder is passed") +// case ofs > 0 && st.Binder == nil: +// panic("BUG: filter has binds but no binder is passed") +// } +// st.Binder(s) +// } else if st.Binder != nil { +// panic("BUG: there's no filter but there's a binder") +// } +// return ofs +// } + +// type bindCountVisitor struct { +// numBinds int +// } + +// var _ rsql.Visitor = &bindCountVisitor{} + +// func (b *bindCountVisitor) Visit(node rsql.Node) (w rsql.Visitor, err error) { +// bExpr, ok := node.(*rsql.BindExpr) +// if !ok { +// return b, nil +// } +// if bExpr.Name != "?" { +// return nil, fmt.Errorf("bad bind %s: only ? binds are supported", bExpr.Name) +// } +// b.numBinds++ +// return nil, nil +// } + +// func (b *bindCountVisitor) VisitEnd(node rsql.Node) error { return nil } diff --git a/sync2/dbsync/syncedtable_test.go b/sync2/dbsync/syncedtable_test.go new file mode 100644 index 0000000000..4d091c7d0a --- /dev/null +++ b/sync2/dbsync/syncedtable_test.go @@ -0,0 +1,299 @@ +package dbsync + +import ( + "testing" + + rsql "github.com/rqlite/sql" + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/common/util" + "github.com/spacemeshos/go-spacemesh/sql" +) + +// func TestRmme(t *testing.T) { +// s, err := rsql.ParseExprString("x.id = ? 
and y.id = ?") +// require.NoError(t, err) +// var v bindCountVisitor +// require.NoError(t, rsql.Walk(&v, s)) +// require.Equal(t, 2, v.numBinds) + +// st, err := rsql.NewParser(strings.NewReader("select max(rowidx) from foo")).ParseStatement() +// require.NoError(t, err) +// spew.Config.DisableMethods = true +// spew.Config.DisablePointerAddresses = true +// defer func() { +// spew.Config.DisableMethods = false +// spew.Config.DisablePointerAddresses = false +// }() +// t.Logf("s: %s\n", spew.Sdump(st)) +// } + +func parseSQLExpr(t *testing.T, s string) rsql.Expr { + expr, err := rsql.ParseExprString(s) + require.NoError(t, err) + return expr +} + +func TestSyncedTable_GenSQL(t *testing.T) { + for _, tc := range []struct { + name string + st SyncedTable + selectAllRC string + selectMaxRowID string + selectIDs string + selectIDsRC string + }{ + { + name: "no filter", + st: SyncedTable{ + TableName: "atxs", + IDColumn: "id", + }, + selectAllRC: `SELECT "id" FROM "atxs" WHERE "rowid" <= ?`, + selectMaxRowID: `SELECT max("rowid") FROM "atxs"`, + selectIDs: `SELECT "id" FROM "atxs" WHERE "id" >= ? ORDER BY "id" LIMIT ?`, + selectIDsRC: `SELECT "id" FROM "atxs" WHERE "id" >= ? AND "rowid" <= ? ` + + `ORDER BY "id" LIMIT ?`, + }, + { + name: "filter", + st: SyncedTable{ + TableName: "atxs", + IDColumn: "id", + Filter: parseSQLExpr(t, "epoch = ?"), + }, + selectAllRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "rowid" <= ?`, + selectMaxRowID: `SELECT max("rowid") FROM "atxs"`, + selectIDs: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + `ORDER BY "id" LIMIT ?`, + selectIDsRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + `AND "rowid" <= ? 
ORDER BY "id" LIMIT ?`, + }, + } { + require.Equal(t, tc.selectAllRC, tc.st.genSelectAllRowIDCutoff().String()) + require.Equal(t, tc.selectMaxRowID, tc.st.genSelectMaxRowID().String()) + require.Equal(t, tc.selectIDs, tc.st.genSelectIDRange().String()) + require.Equal(t, tc.selectIDsRC, tc.st.genSelectIDRangeWithRowIDCutoff().String()) + } +} + +func TestSyncedTable_LoadIDs(t *testing.T) { + var db sql.Database + type row struct { + id string + epoch int + } + rows := []row{ + {"0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", 1}, + {"0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", 1}, + {"18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", 2}, + {"1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", 2}, + {"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", 2}, + {"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", 3}, + {"24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", 3}, + } + + insertRows := func(rows []row) { + for _, r := range rows { + _, err := db.Exec("insert into atxs (id, epoch) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, util.FromHex(r.id)) + stmt.BindInt64(2, int64(r.epoch)) + }, nil) + require.NoError(t, err) + } + } + + initDB := func() { + db = sql.InMemoryTest(t) + _, err := db.Exec("create table atxs (id char(32) not null primary key, epoch int)", nil, nil) + require.NoError(t, err) + insertRows(rows) + } + + loadIDs := func(sts *SyncedTableSnapshot) []string { + var ids []string + require.NoError(t, sts.loadIDs(db, func(stmt *sql.Statement) bool { + id := make(KeyBytes, stmt.ColumnLen(0)) + stmt.ColumnBytes(0, id) + ids = append(ids, id.String()) + return true + })) + return ids + } + + loadIDRange := func(sts *SyncedTableSnapshot, from KeyBytes, limit int) []string { + var ids []string + require.NoError(t, sts.loadIDRange( + db, from, limit, + func(stmt *sql.Statement) bool { + id := make(KeyBytes, 
stmt.ColumnLen(0)) + stmt.ColumnBytes(0, id) + ids = append(ids, id.String()) + return true + })) + return ids + } + + t.Run("no filter", func(t *testing.T) { + initDB() + + st := &SyncedTable{ + TableName: "atxs", + IDColumn: "id", + } + + sts1, err := st.snapshot(db) + require.NoError(t, err) + + require.ElementsMatch(t, + []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadIDs(sts1)) + + fromID := util.FromHex("1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55") + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadIDRange(sts1, fromID, 100)) + + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, loadIDRange(sts1, fromID, 2)) + + insertRows([]row{ + {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2}, + }) + + // the new row is not included in the first snapshot + require.ElementsMatch(t, + []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + 
"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, loadIDs(sts1)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadIDRange(sts1, fromID, 100)) + + sts2, err := st.snapshot(db) + require.NoError(t, err) + + require.ElementsMatch(t, + []string{ + "0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", + "0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, loadIDs(sts2)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDRange(sts2, fromID, 100)) + }) + + t.Run("filter", func(t *testing.T) { + initDB() + st := &SyncedTable{ + TableName: "atxs", + IDColumn: "id", + Filter: parseSQLExpr(t, "epoch = ?"), + Binder: func(stmt *sql.Statement) { + stmt.BindInt64(1, 2) + }, + } + + sts1, err := st.snapshot(db) + require.NoError(t, err) + + 
require.ElementsMatch(t, + []string{ + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDs(sts1)) + + fromID := util.FromHex("1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55") + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDRange(sts1, fromID, 100)) + + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + }, loadIDRange(sts1, fromID, 1)) + + insertRows([]row{ + {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2}, + }) + + // the new row is not included in the first snapshot + require.ElementsMatch(t, + []string{ + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDs(sts1)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadIDRange(sts1, fromID, 100)) + + sts2, err := st.snapshot(db) + require.NoError(t, err) + + require.ElementsMatch(t, + []string{ + "18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDs(sts2)) + require.ElementsMatch(t, + []string{ + "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + 
"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDRange(sts2, fromID, 100)) + }) +} From 25ca1d1ce2f1446c12b4e2630a4bc98d8dc72e71 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sun, 25 Aug 2024 05:15:41 +0400 Subject: [PATCH 65/76] sync2: implement advancing DBItemStore --- sync2/dbsync/dbitemstore.go | 42 +++++++++++++---- sync2/dbsync/dbitemstore_test.go | 80 ++++++++++++++++++++++++++++++-- sync2/dbsync/sqlidstore.go | 4 ++ sync2/dbsync/syncedtable.go | 41 ++++++++++++++++ sync2/dbsync/syncedtable_test.go | 21 +++++++++ 5 files changed, 175 insertions(+), 13 deletions(-) diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 479100152d..61c8c18ee8 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -16,6 +16,8 @@ type DBItemStore struct { db sql.Database ft *fpTree st *SyncedTable + snapshot *SyncedTableSnapshot + dbStore *dbBackedStore keyLen int maxDepth int } @@ -38,6 +40,13 @@ func NewDBItemStore( } } +func (d *DBItemStore) decodeID(stmt *sql.Statement) bool { + id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID + stmt.ColumnBytes(0, id[:]) + d.ft.addStoredHash(id) + return true +} + func (d *DBItemStore) EnsureLoaded(ctx context.Context) error { d.loadMtx.Lock() defer d.loadMtx.Unlock() @@ -45,19 +54,15 @@ func (d *DBItemStore) EnsureLoaded(ctx context.Context) error { return nil } db := ContextSQLExec(ctx, d.db) - sts, err := d.st.snapshot(db) + var err error + d.snapshot, err = d.st.snapshot(db) if err != nil { return fmt.Errorf("error taking snapshot: %w", err) } var np nodePool - dbStore := newDBBackedStore(db, sts, d.keyLen) - d.ft = newFPTree(&np, dbStore, d.keyLen, d.maxDepth) - return sts.loadIDs(db, func(stmt *sql.Statement) bool { - id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID - stmt.ColumnBytes(0, id[:]) - d.ft.addStoredHash(id) - return true - }) + d.dbStore = newDBBackedStore(db, d.snapshot, d.keyLen) + d.ft = newFPTree(&np, 
d.dbStore, d.keyLen, d.maxDepth) + return d.snapshot.loadIDs(db, d.decodeID) } // Add implements hashsync.ItemStore. @@ -174,13 +179,30 @@ func (d *DBItemStore) Min(ctx context.Context) (hashsync.Iterator, error) { return it, nil } +func (d *DBItemStore) Advance(ctx context.Context) error { + d.loadMtx.Lock() + d.loadMtx.Unlock() + if d.ft == nil { + // FIXME + panic("BUG: can't advance the DBItemStore before it's loaded") + } + oldSnapshot := d.snapshot + var err error + d.snapshot, err = d.st.snapshot(ContextSQLExec(ctx, d.db)) + if err != nil { + return fmt.Errorf("error taking snapshot: %w", err) + } + d.dbStore.setSnapshot(d.snapshot) + return d.snapshot.loadIDsSince(d.db, oldSnapshot, d.decodeID) +} + // Copy implements hashsync.ItemStore. func (d *DBItemStore) Copy() hashsync.ItemStore { d.loadMtx.Lock() d.loadMtx.Unlock() if d.ft == nil { // FIXME - panic("BUG: can't copy DBItemStore before it's loaded") + panic("BUG: can't copy the DBItemStore before it's loaded") } return &DBItemStore{ db: d.db, diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go index f1d36e1aa4..974ede4453 100644 --- a/sync2/dbsync/dbitemstore_test.go +++ b/sync2/dbsync/dbitemstore_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestDBItemStoreEmpty(t *testing.T) { +func TestDBItemStore_Empty(t *testing.T) { db := populateDB(t, 32, nil) st := &SyncedTable{ TableName: "foo", @@ -137,7 +137,7 @@ func TestDBItemStore(t *testing.T) { } } -func TestDBItemStoreAdd(t *testing.T) { +func TestDBItemStore_Add(t *testing.T) { ids := []KeyBytes{ util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), @@ -173,7 +173,7 @@ func TestDBItemStoreAdd(t *testing.T) { require.Equal(t, ids[0], itKey(t, info.End)) } -func TestDBItemStoreCopy(t *testing.T) { +func TestDBItemStore_Copy(t *testing.T) { ids := []KeyBytes{ 
util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), @@ -218,3 +218,77 @@ func TestDBItemStoreCopy(t *testing.T) { require.Equal(t, ids[2], itKey(t, info.Start)) require.Equal(t, ids[0], itKey(t, info.End)) } + +func TestDBItemStore_Advance(t *testing.T) { + ids := []KeyBytes{ + util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), + util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), + util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), + util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := populateDB(t, 32, ids) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBItemStore(db, st, 32, 24) + ctx := context.Background() + require.NoError(t, s.EnsureLoaded(ctx)) + + copy := s.Copy() + + info, err := s.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) + + info, err = copy.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) + + insertDBItems(t, db, []KeyBytes{ + util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), + }) + + info, err = s.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, 
info.End)) + + info, err = copy.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) + + require.NoError(t, s.Advance(ctx)) + + info, err = s.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) + + info, err = copy.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) + + info, err = s.Copy().GetRangeInfo(ctx, nil, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.(fmt.Stringer).String()) + require.Equal(t, ids[0], itKey(t, info.Start)) + require.Equal(t, ids[0], itKey(t, info.End)) +} diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 431f088820..bee8bf0931 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -69,6 +69,10 @@ func (s *sqlIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.sts, from, sqlMaxChunkSize, s.cache) } +func (s *sqlIDStore) setSnapshot(sts *SyncedTableSnapshot) { + s.sts = sts +} + type dbBackedStore struct { *sqlIDStore *inMemIDStore diff --git a/sync2/dbsync/syncedtable.go b/sync2/dbsync/syncedtable.go index c7b75427f1..75633495a7 100644 --- a/sync2/dbsync/syncedtable.go +++ b/sync2/dbsync/syncedtable.go @@ -85,6 +85,28 @@ func (st 
*SyncedTable) genSelectAllRowIDCutoff() *rsql.SelectStatement { return s } +func (st *SyncedTable) genSelectAllRowIDCutoffSince() *rsql.SelectStatement { + s := st.genSelectAll() + rowIDBetween := &rsql.BinaryExpr{ + X: &rsql.Ident{Name: "rowid"}, + Op: rsql.BETWEEN, + Y: &rsql.Range{ + X: &rsql.BindExpr{Name: "?"}, + Y: &rsql.BindExpr{Name: "?"}, + }, + } + if s.WhereExpr != nil { + s.WhereExpr = &rsql.BinaryExpr{ + X: s.WhereExpr, + Op: rsql.AND, + Y: rowIDBetween, + } + } else { + s.WhereExpr = rowIDBetween + } + return s +} + func (st *SyncedTable) genSelectIDRangeWithRowIDCutoff() *rsql.SelectStatement { s := st.genSelectIDRange() s.WhereExpr = &rsql.BinaryExpr{ @@ -137,6 +159,25 @@ func (sts *SyncedTableSnapshot) loadIDs( return err } +func (sts *SyncedTableSnapshot) loadIDsSince( + db sql.Executor, + prev *SyncedTableSnapshot, + dec func(stmt *sql.Statement) bool, +) error { + _, err := db.Exec( + sts.genSelectAllRowIDCutoffSince().String(), + func(stmt *sql.Statement) { + if sts.Binder != nil { + sts.Binder(stmt) + } + nParams := stmt.BindParamCount() + stmt.BindInt64(nParams-1, prev.maxRowID+1) + stmt.BindInt64(nParams, sts.maxRowID) + }, + dec) + return err +} + func (sts *SyncedTableSnapshot) loadIDRange( db sql.Executor, fromID KeyBytes, diff --git a/sync2/dbsync/syncedtable_test.go b/sync2/dbsync/syncedtable_test.go index 4d091c7d0a..1c7a6d9eea 100644 --- a/sync2/dbsync/syncedtable_test.go +++ b/sync2/dbsync/syncedtable_test.go @@ -122,6 +122,17 @@ func TestSyncedTable_LoadIDs(t *testing.T) { return ids } + loadIDsSince := func(stsNew, stsOld *SyncedTableSnapshot) []string { + var ids []string + require.NoError(t, stsNew.loadIDsSince(db, stsOld, func(stmt *sql.Statement) bool { + id := make(KeyBytes, stmt.ColumnLen(0)) + stmt.ColumnBytes(0, id) + ids = append(ids, id.String()) + return true + })) + return ids + } + loadIDRange := func(sts *SyncedTableSnapshot, from KeyBytes, limit int) []string { var ids []string require.NoError(t, sts.loadIDRange( @@ 
-221,6 +232,11 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", }, loadIDRange(sts2, fromID, 100)) + require.ElementsMatch(t, + []string{ + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDsSince(sts2, sts1)) }) t.Run("filter", func(t *testing.T) { @@ -295,5 +311,10 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", }, loadIDRange(sts2, fromID, 100)) + require.ElementsMatch(t, + []string{ + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadIDsSince(sts2, sts1)) }) } From c805ec98be83713d4a607978ac3f22f8fff4e3b8 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 26 Aug 2024 14:26:45 +0400 Subject: [PATCH 66/76] sync2: add WithMaxDiff for RangeSetReconciler --- sync2/hashsync/handler.go | 10 +-- sync2/hashsync/handler_test.go | 130 +++++++++++++++++++---------- sync2/hashsync/rangesync.go | 85 +++++++++++++++---- sync2/hashsync/rangesync_test.go | 4 +- sync2/hashsync/wire_types.go | 18 ++-- sync2/hashsync/wire_types_scale.go | 4 +- sync2/hashsync/xorsync_test.go | 1 + 7 files changed, 177 insertions(+), 75 deletions(-) diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index 1f00bf44f2..f8686a551b 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -108,8 +108,8 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { return nil, err } return &m, nil - case MessageTypeProbeResponse: - var m ProbeResponseMessage + case MessageTypeSample: + var m SampleMessage if _, err := codec.DecodeFrom(c.stream, &m); err != nil { return nil, err } @@ -211,8 +211,8 @@ func (c *wireConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) e return c.send(m) } -func (c *wireConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { - m := &ProbeResponseMessage{ +func (c *wireConduit) SendSample(x, 
y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { + m := &SampleMessage{ RangeFingerprint: fingerprint.(types.Hash12), NumItems: uint32(count), Sample: make([]MinhashSampleItem, sampleSize), @@ -224,7 +224,7 @@ func (c *wireConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sa return err } m.Sample[n] = MinhashSampleItemFromHash32(k.(types.Hash32)) - // fmt.Fprintf(os.Stderr, "QQQQQ: m.Sample[%d] = %s\n", n, m.Sample[n]) + // fmt.Fprintf(os.Stderr, "QQQQQ: SEND: m.Sample[%d] = %s (full %s)\n", n, m.Sample[n], k.(types.Hash32).String()) if err := it.Next(); err != nil { return err } diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index c2252ea1fa..9878dfe030 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -411,52 +411,96 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { } } -func testWireSync(t *testing.T, getRequester getRequesterFunc) Requester { - cfg := xorSyncTestConfig{ - // large test: - // maxSendRange: 1, - // numTestHashes: 5000000, - // minNumSpecificA: 15000, - // maxNumSpecificA: 20000, - // minNumSpecificB: 15, - // maxNumSpecificB: 20, - - // QQQQQ: restore! 
- // maxSendRange: 1, - // numTestHashes: 100000, - // minNumSpecificA: 4, - // maxNumSpecificA: 100, - // minNumSpecificB: 4, - // maxNumSpecificB: 100, +type dumbSyncTracer struct { + dumb bool +} - maxSendRange: 1, - numTestHashes: 100, - minNumSpecificA: 2, - maxNumSpecificA: 4, - minNumSpecificB: 2, - maxNumSpecificB: 4, - } - var client Requester - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { - opts = append(opts, WithRangeSyncLogger(zaptest.NewLogger(t))) // QQQQQ: TBD: rm - withClientServer( - storeA, getRequester, opts, - func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { - nr := RmmeNumRead() - nw := RmmeNumWritten() - pss := NewPairwiseStoreSyncer(client, opts) - err := pss.SyncStore(ctx, srvPeerID, storeB, nil, nil) - require.NoError(t, err) +var _ Tracer = &dumbSyncTracer{} - if fr, ok := client.(*fakeRequester); ok { - t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", - numSpecific, fr.bytesSent, fr.bytesReceived) - } - t.Logf("bytes read: %d, bytes written: %d", RmmeNumRead()-nr, RmmeNumWritten()-nw) +func (tr *dumbSyncTracer) OnDumbSync() { + tr.dumb = true +} + +func testWireSync(t *testing.T, getRequester getRequesterFunc) { + for _, tc := range []struct { + name string + cfg xorSyncTestConfig + dumb bool + }{ + { + name: "non-dumb sync", + cfg: xorSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 1000, + minNumSpecificA: 8, + maxNumSpecificA: 16, + minNumSpecificB: 8, + maxNumSpecificB: 16, + }, + dumb: false, + }, + { + name: "dumb sync", + cfg: xorSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 1000, + minNumSpecificA: 400, + maxNumSpecificA: 500, + minNumSpecificB: 400, + maxNumSpecificB: 500, + }, + dumb: true, + }, + { + name: "larger sync", + cfg: xorSyncTestConfig{ + // even larger test: + // maxSendRange: 1, + // numTestHashes: 5000000, + // minNumSpecificA: 15000, + // maxNumSpecificA: 20000, + // minNumSpecificB: 15, + // maxNumSpecificB: 20, + 
+ maxSendRange: 1, + numTestHashes: 100000, + minNumSpecificA: 4, + maxNumSpecificA: 100, + minNumSpecificB: 4, + maxNumSpecificB: 100, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + verifyXORSync(t, tc.cfg, func( + storeA, storeB ItemStore, + numSpecific int, + opts []RangeSetReconcilerOption, + ) bool { + var tr dumbSyncTracer + opts = append(opts, + WithTracer(&tr)) + withClientServer( + storeA, getRequester, opts, + func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { + nr := RmmeNumRead() + nw := RmmeNumWritten() + pss := NewPairwiseStoreSyncer(client, opts) + err := pss.SyncStore(ctx, srvPeerID, storeB, nil, nil) + require.NoError(t, err) + + if fr, ok := client.(*fakeRequester); ok { + t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", + numSpecific, fr.bytesSent, fr.bytesReceived) + } + t.Logf("bytes read: %d, bytes written: %d", + RmmeNumRead()-nr, RmmeNumWritten()-nw) + }) + require.Equal(t, tc.dumb, tr.dumb, "dumb sync") + return true }) - return true - }) - return client + }) + } } func TestWireSync(t *testing.T) { diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index 7abb3a8824..dff93bb358 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -30,7 +30,7 @@ const ( MessageTypeRangeContents MessageTypeItemBatch MessageTypeProbe - MessageTypeProbeResponse + MessageTypeSample ) var messageTypes = []string{ @@ -42,7 +42,7 @@ var messageTypes = []string{ "rangeContents", "itemBatch", "probe", - "probeResponse", + "sample", } func (mtype MessageType) String() string { @@ -132,9 +132,9 @@ type Conduit interface { // the handler must send a sample subset of its items for MinHash // calculation. SendProbe(x, y Ordered, fingerprint any, sampleSize int) error - // SendProbeResponse sends probe response. 
If 'it' is not nil, - // the corresponding items are included in the sample - SendProbeResponse(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error + // SendSample sends a set sample. If 'it' is not nil, the corresponding items are + // included in the sample + SendSample(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error // ShortenKey shortens the key for minhash calculation ShortenKey(k Ordered) Ordered } @@ -159,6 +159,16 @@ func WithSampleSize(s int) RangeSetReconcilerOption { } } +// WithMaxDiff sets maximum estimated relative size of the symmetric difference allowed +// for recursive reconciliation, with value of 0 meaning equal sets and 1 meaning +// completely disjoin set. If the estimated value for the sets exceeds MaxDiff value, the +// whole set is transmitted instead of applying the recursive algorithm. +func WithMaxDiff(d float64) RangeSetReconcilerOption { + return func(r *RangeSetReconciler) { + r.maxDiff = d + } +} + // TODO: RangeSetReconciler should sit in a separate package // and WithRangeSyncLogger should be named WithLogger func WithRangeSyncLogger(log *zap.Logger) RangeSetReconcilerOption { @@ -178,6 +188,7 @@ type RangeSetReconciler struct { maxSendRange int itemChunkSize int sampleSize int + maxDiff float64 log *zap.Logger } @@ -187,6 +198,7 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang maxSendRange: DefaultMaxSendRange, itemChunkSize: DefaultItemChunkSize, sampleSize: DefaultSampleSize, + maxDiff: -1, log: zap.NewNop(), } for _, opt := range opts { @@ -274,7 +286,7 @@ func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre HexField("fingerpint", info.Fingerprint), zap.Int("count", info.Count), IteratorField("it", it)) - if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, 0, it); err != nil { + if err := c.SendSample(x, y, info.Fingerprint, info.Count, 0, it); err != nil { return false, err } } @@ -309,6 +321,8 @@ func 
(rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre // items in the range to us, but there may be some items on our // side. In the latter case, send only the items themselves b/c // the range doesn't need any further handling by the peer. + // QQQQQ: TBD: do NOT send items if they're present in the original + // range received. if info.Count != 0 { done = false rsr.log.Debug("handleMessage: send items", zap.Int("count", info.Count), @@ -335,12 +349,30 @@ func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre // fmt.Fprintf(os.Stderr, "QQQQQ: fingerprint eq %#v %#v\n", // msg.Fingerprint(), info.Fingerprint) } - if err := c.SendProbeResponse(x, y, info.Fingerprint, info.Count, sampleSize, it); err != nil { + if err := c.SendSample(x, y, info.Fingerprint, info.Count, sampleSize, it); err != nil { return false, err } return true, nil - case msg.Type() != MessageTypeFingerprint: + case msg.Type() != MessageTypeFingerprint && msg.Type() != MessageTypeSample: return false, fmt.Errorf("unexpected message type %s", msg.Type()) + case msg.Type() == MessageTypeSample && rsr.maxDiff >= 0: + // The peer has sent a sample of its items in the range to check if + // recursive reconciliation approach is feasible. 
+ pr, err := rsr.handleSample(c, msg, info) + if err != nil { + return false, err + } + if 1-pr.Sim > rsr.maxDiff { + rsr.log.Debug("handleMessage: maxDiff exceeded, sending full range") + if err := c.SendRangeContents(x, y, info.Count); err != nil { + return false, err + } + if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + return false, err + } + break + } + fallthrough case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): // The range is synced // case (info.Count+1)/2 <= rsr.maxSendRange: @@ -422,6 +454,8 @@ func (rsr *RangeSetReconciler) Initiate(ctx context.Context, c Conduit) error { } func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x, y Ordered) error { + // QQQQQ: TBD: add a possibility to send a sample for probing. + // When difference is too high, the remote side should reply with its whole [x, y) range rsr.log.Debug("initiate", HexField("x", x), HexField("y", y)) if x == nil { rsr.log.Debug("initiate: send empty set") @@ -444,6 +478,17 @@ func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { return err } + case rsr.maxDiff >= 0: + // Use minhash to check if syncing this range is feasible + rsr.log.Debug("initiate: send sample", zap.Int("count", info.Count), + zap.Int("sampleSize", rsr.sampleSize)) + it, err := rsr.is.Min(ctx) + if err != nil { + return fmt.Errorf("error getting min element: %v", err) + } + if err := c.SendSample(x, y, info.Fingerprint, info.Count, 0, it); err != nil { + return err + } default: rsr.log.Debug("initiate: send fingerprint", zap.Int("count", info.Count)) if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { @@ -548,6 +593,21 @@ func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ return float64(numEq) / float64(maxSampleSize), nil } +func (rsr *RangeSetReconciler) handleSample( + c Conduit, + msg 
SyncMessage, + info RangeInfo, +) (pr ProbeResult, err error) { + pr.FP = msg.Fingerprint() + pr.Count = msg.Count() + sim, err := rsr.calcSim(c, info, msg.Keys(), msg.Fingerprint()) + if err != nil { + return ProbeResult{}, fmt.Errorf("database error: %w", err) + } + pr.Sim = sim + return pr, nil +} + func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (pr ProbeResult, err error) { // fmt.Fprintf(os.Stderr, "QQQQQ: HandleProbeResponse\n") // defer fmt.Fprintf(os.Stderr, "QQQQQ: HandleProbeResponse done\n") @@ -571,17 +631,14 @@ func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (p return ProbeResult{}, errors.New("no range info received during probe") } return pr, nil - case MessageTypeProbeResponse: + case MessageTypeSample: if gotRange { return ProbeResult{}, errors.New("single range message expected") } - pr.FP = msg.Fingerprint() - pr.Count = msg.Count() - sim, err := rsr.calcSim(c, info, msg.Keys(), msg.Fingerprint()) + pr, err = rsr.handleSample(c, msg, info) if err != nil { - return ProbeResult{}, fmt.Errorf("database error: %w", err) + return ProbeResult{}, err } - pr.Sim = sim gotRange = true case MessageTypeEmptySet, MessageTypeEmptyRange: if gotRange { diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index 24bfbb0059..3f5b4c6961 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -157,9 +157,9 @@ func (fc *fakeConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) return nil } -func (fc *fakeConduit) SendProbeResponse(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { +func (fc *fakeConduit) SendSample(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { msg := rangeMessage{ - mtype: MessageTypeProbeResponse, + mtype: MessageTypeSample, x: x, y: y, fp: fingerprint, diff --git a/sync2/hashsync/wire_types.go b/sync2/hashsync/wire_types.go index 54605228be..57930df388 100644 --- 
a/sync2/hashsync/wire_types.go +++ b/sync2/hashsync/wire_types.go @@ -158,8 +158,8 @@ func MinhashSampleItemFromHash32(h types.Hash32) MinhashSampleItem { return MinhashSampleItem(uint32(h[28])<<24 + uint32(h[29])<<16 + uint32(h[30])<<8 + uint32(h[31])) } -// ProbeResponseMessage is a response to ProbeMessage -type ProbeResponseMessage struct { +// SampleMessage is a sample of set items +type SampleMessage struct { RangeX, RangeY CompactHash32 RangeFingerprint types.Hash12 NumItems uint32 @@ -167,15 +167,15 @@ type ProbeResponseMessage struct { Sample []MinhashSampleItem `scale:"max=1000"` } -var _ SyncMessage = &ProbeResponseMessage{} +var _ SyncMessage = &SampleMessage{} -func (m *ProbeResponseMessage) Type() MessageType { return MessageTypeProbeResponse } -func (m *ProbeResponseMessage) X() Ordered { return m.RangeX.ToOrdered() } -func (m *ProbeResponseMessage) Y() Ordered { return m.RangeY.ToOrdered() } -func (m *ProbeResponseMessage) Fingerprint() any { return m.RangeFingerprint } -func (m *ProbeResponseMessage) Count() int { return int(m.NumItems) } +func (m *SampleMessage) Type() MessageType { return MessageTypeSample } +func (m *SampleMessage) X() Ordered { return m.RangeX.ToOrdered() } +func (m *SampleMessage) Y() Ordered { return m.RangeY.ToOrdered() } +func (m *SampleMessage) Fingerprint() any { return m.RangeFingerprint } +func (m *SampleMessage) Count() int { return int(m.NumItems) } -func (m *ProbeResponseMessage) Keys() []Ordered { +func (m *SampleMessage) Keys() []Ordered { r := make([]Ordered, len(m.Sample)) for n, item := range m.Sample { r[n] = item diff --git a/sync2/hashsync/wire_types_scale.go b/sync2/hashsync/wire_types_scale.go index 90d5032a70..2508250d8c 100644 --- a/sync2/hashsync/wire_types_scale.go +++ b/sync2/hashsync/wire_types_scale.go @@ -322,7 +322,7 @@ func (t *ProbeMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { return total, nil } -func (t *ProbeResponseMessage) EncodeScale(enc *scale.Encoder) (total int, err 
error) { +func (t *SampleMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { { n, err := t.RangeX.EncodeScale(enc) if err != nil { @@ -361,7 +361,7 @@ func (t *ProbeResponseMessage) EncodeScale(enc *scale.Encoder) (total int, err e return total, nil } -func (t *ProbeResponseMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { +func (t *SampleMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { { n, err := t.RangeX.DecodeScale(dec) if err != nil { diff --git a/sync2/hashsync/xorsync_test.go b/sync2/hashsync/xorsync_test.go index a57837a0fc..369560ce9b 100644 --- a/sync2/hashsync/xorsync_test.go +++ b/sync2/hashsync/xorsync_test.go @@ -62,6 +62,7 @@ type xorSyncTestConfig struct { func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool) { opts := []RangeSetReconcilerOption{ WithMaxSendRange(cfg.maxSendRange), + WithMaxDiff(0.05), } numSpecificA := rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA numSpecificB := rand.Intn(cfg.maxNumSpecificB+1-cfg.minNumSpecificB) + cfg.minNumSpecificB From 7363d40cc2d3327d3be931fe436c6d68949c4e71 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 26 Aug 2024 17:12:40 +0400 Subject: [PATCH 67/76] sync2: optimize sending whole range Don't send back items received from the peer --- sync2/hashsync/handler.go | 27 ++---- sync2/hashsync/handler_test.go | 13 +-- sync2/hashsync/rangesync.go | 153 ++++++++++++++++++++++++------- sync2/hashsync/rangesync_test.go | 26 ++---- 4 files changed, 138 insertions(+), 81 deletions(-) diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index f8686a551b..cfd7e77b25 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -165,27 +165,14 @@ func (c *wireConduit) SendRangeContents(x, y Ordered, count int) error { }) } -func (c *wireConduit) SendItems(count, itemChunkSize int, it Iterator) error { - for i := 0; i < 
count; i += itemChunkSize { - // TBD: do not use chunks, just stream the contentkeys - var msg ItemBatchMessage - n := min(itemChunkSize, count-i) - for n > 0 { - k, err := it.Key() - if err != nil { - return err - } - msg.ContentKeys = append(msg.ContentKeys, k.(types.Hash32)) - if err := it.Next(); err != nil { - return err - } - n-- - } - if err := c.send(&msg); err != nil { - return err - } +func (c *wireConduit) SendChunk(items []Ordered) error { + msg := ItemBatchMessage{ + ContentKeys: make([]types.Hash32, len(items)), + } + for n, k := range items { + msg.ContentKeys[n] = k.(types.Hash32) } - return nil + return c.send(&msg) } func (c *wireConduit) SendEndRound() error { diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index 9878dfe030..b1eb8a2ed6 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -152,8 +152,7 @@ func (fs *fakeSend) send(c Conduit) error { case fs.done: return c.SendDone() case len(fs.items) != 0: - items := slices.Clone(fs.items) - return c.SendItems(len(items), 2, &sliceIterator{s: items}) + return c.SendChunk(slices.Clone(fs.items)) case fs.x == nil || fs.y == nil: return c.SendEmptySet() case fs.count == 0: @@ -274,10 +273,7 @@ func TestWireConduit(t *testing.T) { name: "server got 2nd request", expectMsgs: []SyncMessage{ &ItemBatchMessage{ - ContentKeys: []types.Hash32{hs[9], hs[10]}, - }, - &ItemBatchMessage{ - ContentKeys: []types.Hash32{hs[11]}, + ContentKeys: []types.Hash32{hs[9], hs[10], hs[11]}, }, &EndRoundMessage{}, }, @@ -324,10 +320,7 @@ func TestWireConduit(t *testing.T) { NumItems: 2, }, &ItemBatchMessage{ - ContentKeys: []types.Hash32{hs[4], hs[5]}, - }, - &ItemBatchMessage{ - ContentKeys: []types.Hash32{hs[7], hs[8]}, + ContentKeys: []types.Hash32{hs[4], hs[5], hs[7], hs[8]}, }, &EndRoundMessage{}, }, diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index dff93bb358..a5fba67f14 100644 --- a/sync2/hashsync/rangesync.go +++ 
b/sync2/hashsync/rangesync.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "os" "reflect" "slices" "strings" @@ -120,8 +119,8 @@ type Conduit interface { // be included in this sync round. The items themselves are sent via // SendItems SendRangeContents(x, y Ordered, count int) error - // SendItems sends just items without any message - SendItems(count, chunkSize int, it Iterator) error + // SendItems sends a chunk of items + SendChunk(items []Ordered) error // SendEndRound sends a message that signifies the end of sync round SendEndRound() error // SendDone sends a message that notifies the peer that sync is finished @@ -159,10 +158,10 @@ func WithSampleSize(s int) RangeSetReconcilerOption { } } -// WithMaxDiff sets maximum estimated relative size of the symmetric difference allowed -// for recursive reconciliation, with value of 0 meaning equal sets and 1 meaning -// completely disjoin set. If the estimated value for the sets exceeds MaxDiff value, the -// whole set is transmitted instead of applying the recursive algorithm. +// WithMaxDiff sets maximum set difference metric (0..1) allowed for recursive +// reconciliation, with value of 0 meaning equal sets and 1 meaning completely disjoint +// set. If the difference metric MaxDiff value, the whole set is transmitted instead of +// applying the recursive algorithm. 
func WithMaxDiff(d float64) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.maxDiff = d @@ -177,6 +176,24 @@ func WithRangeSyncLogger(log *zap.Logger) RangeSetReconcilerOption { } } +// Tracer tracks the reconciliation process +type Tracer interface { + // OnDumbSync is called when the difference metric exceeds maxDiff and dumb + // reconciliation process is used + OnDumbSync() +} + +type nullTracer struct{} + +func (t nullTracer) OnDumbSync() {} + +// WithTracer specifies a tracer for RangeSetReconciler +func WithTracer(t Tracer) RangeSetReconcilerOption { + return func(r *RangeSetReconciler) { + r.tracer = t + } +} + type ProbeResult struct { FP any Count int @@ -189,6 +206,7 @@ type RangeSetReconciler struct { itemChunkSize int sampleSize int maxDiff float64 + tracer Tracer log *zap.Logger } @@ -199,6 +217,7 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang itemChunkSize: DefaultItemChunkSize, sampleSize: DefaultSampleSize, maxDiff: -1, + tracer: nullTracer{}, log: zap.NewNop(), } for _, opt := range opts { @@ -262,7 +281,60 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info Ran return nil } -func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, preceding Iterator, msg SyncMessage) (done bool, err error) { +func (rsr *RangeSetReconciler) sendItems( + c Conduit, + count, itemChunkSize int, + it Iterator, + skipKeys []Ordered, +) error { + skipPos := 0 + for i := 0; i < count; i += itemChunkSize { + // TBD: do not use chunks, just stream the contentkeys + var keys []Ordered + n := min(itemChunkSize, count-i) + IN_CHUNK: + for n > 0 { + k, err := it.Key() + if err != nil { + return err + } + for skipPos < len(skipKeys) { + cmp := k.Compare(skipKeys[skipPos]) + if cmp == 0 { + // fmt.Fprintf(os.Stderr, "QQQQQ: skip key %s\n", k.(fmt.Stringer).String()) + // we can skip this item. 
Advance skipPos as there are no duplicates + skipPos++ + continue IN_CHUNK + } + if cmp < 0 { + // current ley is yet to reach the skipped key at skipPos + break + } + // current item is greater than the skipped key at skipPos, + // so skipPos needs to catch up with the iterator + skipPos++ + } + keys = append(keys, k) + if err := it.Next(); err != nil { + return err + } + n-- + } + if err := c.SendChunk(keys); err != nil { + return err + } + } + return nil +} + +func (rsr *RangeSetReconciler) handleMessage( + ctx context.Context, + c Conduit, + preceding Iterator, + msgs []SyncMessage, + msgPos int, +) (done bool, err error) { + msg := msgs[msgPos] rsr.log.Debug("handleMessage", IteratorField("preceding", preceding), zap.String("msg", SyncMessageToString(msg))) x := msg.X() @@ -321,13 +393,20 @@ func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre // items in the range to us, but there may be some items on our // side. In the latter case, send only the items themselves b/c // the range doesn't need any further handling by the peer. - // QQQQQ: TBD: do NOT send items if they're present in the original - // range received. if info.Count != 0 { done = false rsr.log.Debug("handleMessage: send items", zap.Int("count", info.Count), IteratorField("start", info.Start)) - if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + var skipKeys []Ordered + if msg.Type() == MessageTypeRangeContents { + for _, m := range msgs[msgPos+1:] { + if m.Type() != MessageTypeItemBatch { + break + } + skipKeys = append(skipKeys, m.Keys()...) 
+ } + } + if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, skipKeys); err != nil { return false, err } } else { @@ -355,6 +434,8 @@ func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre return true, nil case msg.Type() != MessageTypeFingerprint && msg.Type() != MessageTypeSample: return false, fmt.Errorf("unexpected message type %s", msg.Type()) + case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): + // The range is synced case msg.Type() == MessageTypeSample && rsr.maxDiff >= 0: // The peer has sent a sample of its items in the range to check if // recursive reconciliation approach is feasible. @@ -363,18 +444,25 @@ func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre return false, err } if 1-pr.Sim > rsr.maxDiff { - rsr.log.Debug("handleMessage: maxDiff exceeded, sending full range") + done = false + rsr.tracer.OnDumbSync() + rsr.log.Debug("handleMessage: maxDiff exceeded, sending full range", + zap.Float64("sim", pr.Sim), + zap.Float64("diff", 1-pr.Sim), + zap.Float64("maxDiff", rsr.maxDiff)) if err := c.SendRangeContents(x, y, info.Count); err != nil { return false, err } - if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, nil); err != nil { return false, err } break } + rsr.log.Debug("handleMessage: acceptable maxDiff, proceeding with sync", + zap.Float64("sim", pr.Sim), + zap.Float64("diff", 1-pr.Sim), + zap.Float64("maxDiff", rsr.maxDiff)) fallthrough - case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): - // The range is synced // case (info.Count+1)/2 <= rsr.maxSendRange: case info.Count <= rsr.maxSendRange: // The range differs from the peer's version of it, but the it @@ -388,7 +476,7 @@ func (rsr *RangeSetReconciler) handleMessage(ctx context.Context, c Conduit, pre if err := c.SendRangeContents(x, y, info.Count); err != nil { return false, err } - if err := 
c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, nil); err != nil { return false, err } } else { @@ -475,18 +563,15 @@ func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x if err := c.SendRangeContents(x, y, info.Count); err != nil { return err } - if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, nil); err != nil { return err } case rsr.maxDiff >= 0: // Use minhash to check if syncing this range is feasible - rsr.log.Debug("initiate: send sample", zap.Int("count", info.Count), + rsr.log.Debug("initiate: send sample", + zap.Int("count", info.Count), zap.Int("sampleSize", rsr.sampleSize)) - it, err := rsr.is.Min(ctx) - if err != nil { - return fmt.Errorf("error getting min element: %v", err) - } - if err := c.SendSample(x, y, info.Fingerprint, info.Count, 0, it); err != nil { + if err := c.SendSample(x, y, info.Fingerprint, info.Count, rsr.sampleSize, info.Start); err != nil { return err } default: @@ -553,17 +638,21 @@ func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ if info.Start == nil { return 0, nil } + // for n, k := range remoteSample { + // fmt.Fprintf(os.Stderr, "QQQQQ: remoteSample[%d] = %s\n", n, k.(MinhashSampleItem).String()) + // } sampleSize := min(info.Count, rsr.sampleSize) localSample := make([]Ordered, sampleSize) it := info.Start for n := 0; n < sampleSize; n++ { - // fmt.Fprintf(os.Stderr, "QQQQQ: n %d sampleSize %d info.Count %d rsr.sampleSize %d %#v\n", - // n, sampleSize, info.Count, rsr.sampleSize, it.Key()) k, err := it.Key() if err != nil { return 0, err } localSample[n] = c.ShortenKey(k) + // fmt.Fprintf(os.Stderr, "QQQQQ: n %d sampleSize %d info.Count %d rsr.sampleSize %d -- %s -> %s\n", + // n, sampleSize, info.Count, rsr.sampleSize, k.(types.Hash32).String(), + // 
localSample[n].(MinhashSampleItem).String()) if err := it.Next(); err != nil { return 0, err } @@ -576,7 +665,8 @@ func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ d := localSample[m].Compare(remoteSample[n]) switch { case d < 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: less: %v < %s\n", c.ShortenKey(it.Key()), remoteSample[n]) + // k, _ := it.Key() + // fmt.Fprintf(os.Stderr, "QQQQQ: less: %v < %s\n", c.ShortenKey(k), remoteSample[n]) m++ case d == 0: // fmt.Fprintf(os.Stderr, "QQQQQ: eq: %v\n", remoteSample[n]) @@ -584,7 +674,8 @@ func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample [ m++ n++ default: - // fmt.Fprintf(os.Stderr, "QQQQQ: gt: %v > %s\n", c.ShortenKey(it.Key()), remoteSample[n]) + // k, _ := it.Key() + // fmt.Fprintf(os.Stderr, "QQQQQ: gt: %v > %s\n", c.ShortenKey(k), remoteSample[n]) n++ } } @@ -665,12 +756,12 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo if done { // items already added if len(msgs) != 0 { - return false, errors.New("non-item messages with 'done' marker") + return false, errors.New("no extra messages expected along with 'done' message") } return done, nil } done = true - for _, msg := range msgs { + for n, msg := range msgs { if msg.Type() == MessageTypeItemBatch { for _, k := range msg.Keys() { rsr.log.Debug("Process: add item", HexField("item", k)) @@ -691,7 +782,7 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo // breaks if we capture the iterator from handleMessage and // pass it to the next handleMessage call (it shouldn't) var msgDone bool - msgDone, err = rsr.handleMessage(ctx, c, nil, msg) + msgDone, err = rsr.handleMessage(ctx, c, nil, msgs, n) if !msgDone { done = false } @@ -748,7 +839,7 @@ func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { return nil, err } if k == nil { - fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) + // fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) panic("BUG: 
iterator exausted before Count reached") } r = append(r, k.(K)) diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index 3f5b4c6961..0743088bf4 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -113,26 +113,12 @@ func (fc *fakeConduit) SendRangeContents(x, y Ordered, count int) error { return nil } -func (fc *fakeConduit) SendItems(count, itemChunkSize int, it Iterator) error { - require.Positive(fc.t, count) - require.NotZero(fc.t, count) - require.NotNil(fc.t, it) - for i := 0; i < count; i += itemChunkSize { - msg := rangeMessage{mtype: MessageTypeItemBatch} - n := min(itemChunkSize, count-i) - for n > 0 { - k, err := it.Key() - if err != nil { - return fmt.Errorf("getting item: %w", err) - } - msg.keys = append(msg.keys, k) - if err := it.Next(); err != nil { - return err - } - n-- - } - fc.sendMsg(msg) - } +func (fc *fakeConduit) SendChunk(items []Ordered) error { + require.NotEmpty(fc.t, items) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeItemBatch, + keys: items, + }) return nil } From 568402202450c7000d20cf7182dd61e64ed2c8a6 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 27 Aug 2024 05:52:49 +0400 Subject: [PATCH 68/76] sync2: fix fptree test --- sync2/dbsync/fptree_test.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 5ddfce6d96..3c48d085b7 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -197,6 +197,7 @@ func TestCommonPrefix(t *testing.T) { type fakeIDDBStore struct { db sql.Database + t *testing.T *sqlIDStore } @@ -209,7 +210,7 @@ func newFakeATXIDStore(t *testing.T, db sql.Database, maxDepth int) *fakeIDDBSto } sts, err := st.snapshot(db) require.NoError(t, err) - return &fakeIDDBStore{db: db, sqlIDStore: newSQLIDStore(db, sts, 32)} + return &fakeIDDBStore{db: db, t: t, sqlIDStore: newSQLIDStore(db, sts, 32)} } func (s *fakeIDDBStore) registerHash(h 
KeyBytes) error { @@ -220,6 +221,9 @@ func (s *fakeIDDBStore) registerHash(h KeyBytes) error { func(stmt *sql.Statement) { stmt.BindBytes(1, h) }, nil) + sts, err := s.sqlIDStore.sts.snapshot(s.db) + require.NoError(s.t, err) + s.sts = sts return err } From 263369fda2f1e90b4846b28abf7d75cc8d538032 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 31 Aug 2024 05:06:09 +0400 Subject: [PATCH 69/76] sync2: add sync priming via sending recent items --- go.mod | 2 +- sync2/dbsync/dbitemstore.go | 15 + sync2/dbsync/dbiter.go | 45 ++- sync2/dbsync/dbiter_test.go | 2 +- sync2/dbsync/p2p_test.go | 305 +++++++++++--- sync2/dbsync/sqlidstore.go | 17 +- sync2/dbsync/sqlidstore_test.go | 59 ++- sync2/dbsync/syncedtable.go | 155 +++++--- sync2/dbsync/syncedtable_test.go | 227 +++++++---- sync2/hashsync/handler.go | 13 + sync2/hashsync/handler_test.go | 135 ++++++- sync2/hashsync/interface.go | 5 + sync2/hashsync/mocks_test.go | 41 ++ sync2/hashsync/rangesync.go | 617 +++++++++++++++++++++-------- sync2/hashsync/rangesync_test.go | 21 +- sync2/hashsync/sync_tree_store.go | 5 + sync2/hashsync/wire_types.go | 25 +- sync2/hashsync/wire_types_scale.go | 23 ++ sync2/hashsync/xorsync_test.go | 15 +- sync2/p2p_test.go | 3 +- 20 files changed, 1331 insertions(+), 399 deletions(-) diff --git a/go.mod b/go.mod index 0718e88027..771fbfc930 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/spacemeshos/go-spacemesh -go 1.22.4 +go 1.23.0 require ( cloud.google.com/go/storage v1.43.0 diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go index 61c8c18ee8..9ab86985ad 100644 --- a/sync2/dbsync/dbitemstore.go +++ b/sync2/dbsync/dbitemstore.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sync" + "time" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" @@ -231,6 +232,11 @@ func (d *DBItemStore) Has(ctx context.Context, k hashsync.Ordered) (bool, error) return itK.Compare(k) == 0, nil } +// Recent implements 
hashsync.ItemStore. +func (d *DBItemStore) Recent(ctx context.Context, since time.Time) (hashsync.Iterator, int, error) { + return d.dbStore.iterSince(ctx, make(KeyBytes, d.keyLen), since.UnixNano()) +} + // TODO: get rid of ItemStoreAdapter, it shouldn't be needed type ItemStoreAdapter struct { s *DBItemStore @@ -343,6 +349,15 @@ func (a *ItemStoreAdapter) Min(ctx context.Context) (hashsync.Iterator, error) { return a.wrapIterator(it), nil } +// Recent implements hashsync.ItemStore. +func (d *ItemStoreAdapter) Recent(ctx context.Context, since time.Time) (hashsync.Iterator, int, error) { + it, count, err := d.s.Recent(ctx, since) + if err != nil { + return nil, 0, err + } + return d.wrapIterator(it), count, nil +} + type iteratorAdapter struct { it hashsync.Iterator } diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go index 88c174780c..8d1480c33d 100644 --- a/sync2/dbsync/dbiter.go +++ b/sync2/dbsync/dbiter.go @@ -77,6 +77,7 @@ type dbRangeIterator struct { from KeyBytes sts *SyncedTableSnapshot chunkSize int + ts int64 maxChunkSize int chunk []KeyBytes pos int @@ -94,6 +95,7 @@ func newDBRangeIterator( db sql.Executor, sts *SyncedTableSnapshot, from KeyBytes, + ts int64, maxChunkSize int, lru *lru, ) hashsync.Iterator { @@ -108,6 +110,7 @@ func newDBRangeIterator( from: from.Clone(), sts: sts, chunkSize: 1, + ts: ts, maxChunkSize: maxChunkSize, keyLen: len(from), chunk: make([]KeyBytes, maxChunkSize), @@ -118,6 +121,9 @@ func newDBRangeIterator( } func (it *dbRangeIterator) loadCached(key dbIDKey) (bool, int) { + if it.cache == nil { + return false, 0 + } chunk, ok := it.cache.Get(key) if !ok { // fmt.Fprintf(os.Stderr, "QQQQQ: cache miss\n") @@ -156,24 +162,27 @@ func (it *dbRangeIterator) load() error { var ierr, err error found, n := it.loadCached(key) if !found { - err := it.sts.loadIDRange( - it.db, it.from, it.chunkSize, - func(stmt *sql.Statement) bool { - if n >= len(it.chunk) { - ierr = errors.New("too many rows") - return false - } - // we 
reuse existing slices when possible for retrieving new IDs - id := it.chunk[n] - if id == nil { - id = make([]byte, it.keyLen) - it.chunk[n] = id - } - stmt.ColumnBytes(0, id) - n++ - return true - }) - if err == nil && ierr == nil { + dec := func(stmt *sql.Statement) bool { + if n >= len(it.chunk) { + ierr = errors.New("too many rows") + return false + } + // we reuse existing slices when possible for retrieving new IDs + id := it.chunk[n] + if id == nil { + id = make([]byte, it.keyLen) + it.chunk[n] = id + } + stmt.ColumnBytes(0, id) + n++ + return true + } + if it.ts <= 0 { + err = it.sts.loadIDRange(it.db, it.from, it.chunkSize, dec) + } else { + err = it.sts.loadRecent(it.db, it.from, it.chunkSize, it.ts, dec) + } + if err == nil && ierr == nil && it.cache != nil { cached := make([]KeyBytes, n) for n, id := range it.chunk[:n] { cached[n] = slices.Clone(id) diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go index 4e6e12a68f..ae526585f8 100644 --- a/sync2/dbsync/dbiter_test.go +++ b/sync2/dbsync/dbiter_test.go @@ -298,7 +298,7 @@ func TestDBRangeIterator(t *testing.T) { sts, err := st.snapshot(db) require.NoError(t, err) for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { - it := newDBRangeIterator(db, sts, tc.from, maxChunkSize, cache) + it := newDBRangeIterator(db, sts, tc.from, -1, maxChunkSize, cache) if tc.expErr != nil { _, err := it.Key() require.ErrorIs(t, err, tc.expErr) diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index bacd93a8d1..35046e86d7 100644 --- a/sync2/dbsync/p2p_test.go +++ b/sync2/dbsync/p2p_test.go @@ -3,12 +3,14 @@ package dbsync import ( "context" "encoding/binary" + "encoding/hex" "errors" "io" "slices" "testing" "time" + "github.com/jonboulle/clockwork" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,29 +18,79 @@ import ( "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/common/types" - 
"github.com/spacemeshos/go-spacemesh/common/util" "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sync2/hashsync" ) -func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { +var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) + +type fooRow struct { + id KeyBytes + ts int64 +} + +func populateFoo(t *testing.T, rows []fooRow) sql.Database { + db := sql.InMemoryTest(t) + _, err := db.Exec( + "create table foo(id char(32) not null primary key, received int)", + nil, nil) + require.NoError(t, err) + for _, row := range rows { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) + } + return db +} + +type syncTracer struct { + dumb bool + receivedItems int + sentItems int +} + +var _ hashsync.Tracer = &syncTracer{} + +func (tr *syncTracer) OnDumbSync() { + // QQQQQQ: use mutex and also update handler_test.go in hashsync!!!! 
+ tr.dumb = true +} + +func (tr *syncTracer) OnRecent(receivedItems, sentItems int) { + tr.receivedItems += receivedItems + tr.sentItems += sentItems +} + +func verifyP2P( + t *testing.T, + rowsA, rowsB []fooRow, + combinedItems []KeyBytes, + clockAt time.Time, + receivedRecent, sentRecent bool, + opts ...hashsync.RangeSetReconcilerOption, +) { nr := hashsync.RmmeNumRead() nw := hashsync.RmmeNumWritten() const maxDepth = 24 log := zaptest.NewLogger(t) t.Logf("QQQQQ: 0") - dbA := populateDB(t, 32, itemsA) + dbA := populateFoo(t, rowsA) t.Logf("QQQQQ: 1") - dbB := populateDB(t, 32, itemsB) + dbB := populateFoo(t, rowsB) mesh, err := mocknet.FullMeshConnected(2) require.NoError(t, err) proto := "itest" t.Logf("QQQQQ: 2") ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", + TableName: "foo", + IDColumn: "id", + TimestampColumn: "received", } storeA := NewItemStoreAdapter(NewDBItemStore(dbA, st, 32, maxDepth)) t.Logf("QQQQQ: 2.1") @@ -87,14 +139,22 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { // storeB.s.ft.dump(&sb) // t.Logf("storeB:\n%s", sb.String()) + var tr syncTracer + opts = append(opts, + hashsync.WithRangeReconcilerClock(clockwork.NewFakeClockAt(clockAt)), + hashsync.WithTracer(&tr), + ) + opts = opts[:len(opts):len(opts)] + srvPeerID := mesh.Hosts()[0].ID() srv := server.New(mesh.Hosts()[0], proto, func(ctx context.Context, req []byte, stream io.ReadWriter) error { - pss := hashsync.NewPairwiseStoreSyncer(nil, []hashsync.RangeSetReconcilerOption{ + pss := hashsync.NewPairwiseStoreSyncer(nil, append( + opts, hashsync.WithMaxSendRange(1), // uncomment to enable verbose logging which may slow down tests - // hashsync.WithRangeSyncLogger(log.Named("sideA")), - }) + // hashsync.WithRangeReconcilerLogger(log.Named("sideA")), + )) return dbA.WithTx(ctx, func(tx sql.Transaction) error { return pss.Serve(WithSQLExec(ctx, tx), req, stream, storeA) }) @@ 
-132,11 +192,12 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { return true }, time.Second, 10*time.Millisecond) - pss := hashsync.NewPairwiseStoreSyncer(client, []hashsync.RangeSetReconcilerOption{ + pss := hashsync.NewPairwiseStoreSyncer(client, append( + opts, hashsync.WithMaxSendRange(1), // uncomment to enable verbose logging which may slow down tests - // hashsync.WithRangeSyncLogger(log.Named("sideB")), - }) + // hashsync.WithRangeReconcilerLogger(log.Named("sideB")), + )) tStart := time.Now() require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { @@ -145,6 +206,9 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { t.Logf("synced in %v", time.Since(tStart)) t.Logf("bytes read: %d, bytes written: %d", hashsync.RmmeNumRead()-nr, hashsync.RmmeNumWritten()-nw) + require.Equal(t, receivedRecent, tr.receivedItems > 0) + require.Equal(t, sentRecent, tr.sentItems > 0) + // // QQQQQ: rmme // sb = strings.Builder{} // storeA.s.ft.dump(&sb) @@ -205,73 +269,171 @@ func verifyP2P(t *testing.T, itemsA, itemsB, combinedItems []KeyBytes) { assert.Equal(t, actItemsA, actItemsB) } +func fooR(id string, seconds int) fooRow { + return fooRow{ + hexID(id), + startDate.Add(time.Duration(seconds) * time.Second).UnixNano(), + } +} + +func hexID(s string) KeyBytes { + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} + func TestP2P(t *testing.T) { t.Run("predefined items", func(t *testing.T) { verifyP2P( - t, []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), + t, []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 10), + 
fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 20), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 30), + fooR("abcdef1234567890000000000000000000000000000000000000000000000000", 40), }, - []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), }, []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), - }) + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + hexID("abcdef1234567890000000000000000000000000000000000000000000000000"), + }, + startDate, + false, + false, + ) }) t.Run("predefined items 2", func(t *testing.T) { verifyP2P( - t, []KeyBytes{ - util.FromHex("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), - 
util.FromHex("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), - util.FromHex("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), - util.FromHex("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + t, []fooRow{ + fooR("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236", 10), + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 20), + fooR("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90", 30), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 40), + }, + []fooRow{ + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 11), + fooR("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701", 12), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 13), + fooR("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f", 14), }, []KeyBytes{ - util.FromHex("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), - util.FromHex("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), - util.FromHex("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), - util.FromHex("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + hexID("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), + hexID("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + hexID("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), + hexID("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), + hexID("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + hexID("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + }, + startDate, + false, + false, + ) + }) + t.Run("predefined items with recent", func(t *testing.T) { + verifyP2P( + t, []fooRow{ + fooR("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236", 10), + 
fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 20), + fooR("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90", 30), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 40), + }, + []fooRow{ + fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 11), + fooR("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701", 12), + fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 13), + fooR("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f", 14), }, []KeyBytes{ - util.FromHex("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), - util.FromHex("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), - util.FromHex("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), - util.FromHex("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), - util.FromHex("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), - util.FromHex("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), - }) + hexID("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), + hexID("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), + hexID("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), + hexID("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), + hexID("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), + hexID("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + }, + startDate.Add(time.Minute), + true, + true, + hashsync.WithRecentTimeSpan(48*time.Second), + ) }) t.Run("empty to non-empty", func(t *testing.T) { verifyP2P( t, nil, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + 
fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, + []KeyBytes{ + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate, + false, + false, + ) + }) + t.Run("empty to non-empty with recent", func(t *testing.T) { + verifyP2P( + t, nil, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), + }, []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate.Add(time.Minute), + true, + true, + hashsync.WithRecentTimeSpan(48*time.Second), + ) + }) + t.Run("non-empty to empty with recent", func(t *testing.T) { + verifyP2P( + t, + []fooRow{ + fooR("1111111111111111111111111111111111111111111111111111111111111111", 11), + fooR("123456789abcdef0000000000000000000000000000000000000000000000000", 12), + fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), + 
fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), }, + nil, []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - }) + hexID("1111111111111111111111111111111111111111111111111111111111111111"), + hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), + hexID("5555555555555555555555555555555555555555555555555555555555555555"), + hexID("8888888888888888888888888888888888888888888888888888888888888888"), + }, + startDate.Add(time.Minute), + // no actual recent exchange happens due to the initial EmptySet message + false, + false, + hashsync.WithRecentTimeSpan(48*time.Second), + ) }) t.Run("empty to empty", func(t *testing.T) { - verifyP2P(t, nil, nil, nil) + verifyP2P(t, nil, nil, nil, startDate, false, false) }) t.Run("random test", func(t *testing.T) { - // TODO: increase these values and profile // const nShared = 8000000 // const nUniqueA = 100 // const nUniqueB = 80000 @@ -285,21 +447,32 @@ func TestP2P(t *testing.T) { // const nUniqueA = 2 // const nUniqueB = 2 combined := make([]KeyBytes, 0, nShared+nUniqueA+nUniqueB) - itemsA := make([]KeyBytes, nShared+nUniqueA) - for i := range itemsA { + rowsA := make([]fooRow, nShared+nUniqueA) + for i := range rowsA { h := types.RandomHash() - itemsA[i] = KeyBytes(h[:]) - combined = append(combined, itemsA[i]) + k := KeyBytes(h[:]) + rowsA[i] = fooRow{ + id: k, + ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), + } + combined = append(combined, k) // t.Logf("itemsA[%d] = %s", i, itemsA[i]) } - itemsB := make([]KeyBytes, nShared+nUniqueB) - for i := range itemsB { + rowsB := make([]fooRow, nShared+nUniqueB) + for i := range rowsB { if i < nShared { - itemsB[i] = 
slices.Clone(itemsA[i]) + rowsB[i] = fooRow{ + id: slices.Clone(rowsA[i].id), + ts: rowsA[i].ts, + } } else { h := types.RandomHash() - itemsB[i] = KeyBytes(h[:]) - combined = append(combined, itemsB[i]) + k := KeyBytes(h[:]) + rowsB[i] = fooRow{ + id: k, + ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), + } + combined = append(combined, k) } // t.Logf("itemsB[%d] = %s", i, itemsB[i]) } @@ -309,7 +482,9 @@ func TestP2P(t *testing.T) { // for i, v := range combined { // t.Logf("combined[%d] = %s", i, v) // } - verifyP2P(t, itemsA, itemsB, combined) + verifyP2P(t, rowsA, rowsB, combined, startDate, false, false) // TODO: multiple iterations }) } + +// QQQQQ: TBD empty sets with recent diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index bee8bf0931..177616a384 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -66,7 +66,22 @@ func (s *sqlIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator if len(from) != s.keyLen { panic("BUG: invalid key length") } - return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.sts, from, sqlMaxChunkSize, s.cache) + return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.sts, from, -1, sqlMaxChunkSize, s.cache) +} + +func (s *sqlIDStore) iterSince(ctx context.Context, from KeyBytes, since int64) (hashsync.Iterator, int, error) { + if len(from) != s.keyLen { + panic("BUG: invalid key length") + } + db := ContextSQLExec(ctx, s.db) + count, err := s.sts.loadRecentCount(db, since) + if err != nil { + return nil, 0, err + } + if count == 0 { + return nil, 0, nil + } + return newDBRangeIterator(db, s.sts, from, since, sqlMaxChunkSize, nil), count, nil } func (s *sqlIDStore) setSnapshot(sts *SyncedTableSnapshot) { diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index 87fd42e1df..03b1c1cc5b 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ b/sync2/dbsync/sqlidstore_test.go @@ -4,20 +4,49 @@ import ( "context" "testing" + 
"github.com/spacemeshos/go-spacemesh/sql" "github.com/stretchr/testify/require" ) func TestDBBackedStore(t *testing.T) { - initialIDs := []KeyBytes{ - {0, 0, 0, 1, 0, 0, 0, 0}, - {0, 0, 0, 3, 0, 0, 0, 0}, - {0, 0, 0, 5, 0, 0, 0, 0}, - {0, 0, 0, 7, 0, 0, 0, 0}, + db := sql.InMemoryTest(t) + _, err := db.Exec( + "create table foo(id char(8) not null primary key, received int)", + nil, nil) + require.NoError(t, err) + for _, row := range []struct { + id KeyBytes + ts int64 + }{ + { + id: KeyBytes{0, 0, 0, 1, 0, 0, 0, 0}, + ts: 100, + }, + { + id: KeyBytes{0, 0, 0, 3, 0, 0, 0, 0}, + ts: 200, + }, + { + id: KeyBytes{0, 0, 0, 5, 0, 0, 0, 0}, + ts: 300, + }, + { + id: KeyBytes{0, 0, 0, 7, 0, 0, 0, 0}, + ts: 400, + }, + } { + _, err := db.Exec( + "insert into foo (id, received) values (?, ?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, row.id) + stmt.BindInt64(2, row.ts) + }, nil) + require.NoError(t, err) } - db := populateDB(t, 8, initialIDs) st := SyncedTable{ - TableName: "foo", - IDColumn: "id", + TableName: "foo", + IDColumn: "id", + TimestampColumn: "received", } sts, err := st.snapshot(db) require.NoError(t, err) @@ -43,6 +72,20 @@ func TestDBBackedStore(t *testing.T) { require.NoError(t, it.Next()) } + actualIDs = nil + it, count, err := store.iterSince(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}, 300) + require.NoError(t, err) + require.Equal(t, 2, count) + for range 3 { + actualIDs = append(actualIDs, itKey(t, it)) + require.NoError(t, it.Next()) + } + require.Equal(t, []KeyBytes{ + {0, 0, 0, 5, 0, 0, 0, 0}, + {0, 0, 0, 7, 0, 0, 0, 0}, + {0, 0, 0, 5, 0, 0, 0, 0}, // wrapped around + }, actualIDs) + require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) actualIDs = nil diff --git a/sync2/dbsync/syncedtable.go b/sync2/dbsync/syncedtable.go index 75633495a7..aba9ac9e73 100644 --- a/sync2/dbsync/syncedtable.go +++ b/sync2/dbsync/syncedtable.go @@ -10,16 +10,19 @@ import ( 
type Binder func(s *sql.Statement) type SyncedTable struct { - TableName string - IDColumn string - Filter rsql.Expr - Binder Binder + TableName string + IDColumn string + TimestampColumn string + Filter rsql.Expr + Binder Binder } func (st *SyncedTable) genSelectAll() *rsql.SelectStatement { return &rsql.SelectStatement{ Columns: []*rsql.ResultColumn{ - {Expr: &rsql.Ident{Name: st.IDColumn}}, + { + Expr: &rsql.Ident{Name: st.IDColumn}, + }, }, Source: &rsql.QualifiedTableName{Name: &rsql.Ident{Name: st.TableName}}, WhereExpr: st.Filter, @@ -117,6 +120,55 @@ func (st *SyncedTable) genSelectIDRangeWithRowIDCutoff() *rsql.SelectStatement { return s } +func (st *SyncedTable) genSelectRecentRowIDCutoff() *rsql.SelectStatement { + s := st.genSelectIDRangeWithRowIDCutoff() + s.WhereExpr = &rsql.BinaryExpr{ + X: s.WhereExpr, + Op: rsql.AND, + Y: &rsql.BinaryExpr{ + X: &rsql.Ident{Name: st.TimestampColumn}, + Op: rsql.GE, + Y: &rsql.BindExpr{Name: "?"}, + }, + } + return s +} + +func (st *SyncedTable) genRecentCount() *rsql.SelectStatement { + where := &rsql.BinaryExpr{ + X: &rsql.BinaryExpr{ + X: &rsql.Ident{Name: "rowid"}, + Op: rsql.LE, + Y: &rsql.BindExpr{Name: "?"}, + }, + Op: rsql.AND, + Y: &rsql.BinaryExpr{ + X: &rsql.Ident{Name: st.TimestampColumn}, + Op: rsql.GE, + Y: &rsql.BindExpr{Name: "?"}, + }, + } + if st.Filter != nil { + where = &rsql.BinaryExpr{ + X: st.Filter, + Op: rsql.AND, + Y: where, + } + } + return &rsql.SelectStatement{ + Columns: []*rsql.ResultColumn{ + { + Expr: &rsql.Call{ + Name: &rsql.Ident{Name: "count"}, + Args: []rsql.Expr{&rsql.Ident{Name: st.IDColumn}}, + }, + }, + }, + Source: &rsql.QualifiedTableName{Name: &rsql.Ident{Name: st.TableName}}, + WhereExpr: where, + } +} + func (st *SyncedTable) loadMaxRowID(db sql.Executor) (maxRowID int64, err error) { nRows, err := db.Exec( st.genSelectMaxRowID().String(), nil, @@ -191,6 +243,9 @@ func (sts *SyncedTableSnapshot) loadIDRange( sts.Binder(stmt) } nParams := stmt.BindParamCount() + // 
fmt.Fprintf(os.Stderr, "QQQQQ: STMT: %s\nfromID %s maxRowID %d limit %d\n", + // sts.genSelectIDRangeWithRowIDCutoff().String(), + // fromID.String(), sts.maxRowID, limit) stmt.BindBytes(nParams-2, fromID) stmt.BindInt64(nParams-1, sts.maxRowID) stmt.BindInt64(nParams, int64(limit)) @@ -199,43 +254,53 @@ func (sts *SyncedTableSnapshot) loadIDRange( return err } -// func (st *SyncedTable) bind(s *sql.Statement) int { -// ofs := 0 -// if st.Filter != nil { -// var v bindCountVisitor -// if err := rsql.Walk(&v, st.Filter); err != nil { -// panic("BUG: bad filter: " + err.Error()) -// } -// ofs = v.numBinds -// switch { -// case ofs == 0 && st.Binder != nil: -// panic("BUG: filter has no binds but a binder is passed") -// case ofs > 0 && st.Binder == nil: -// panic("BUG: filter has binds but no binder is passed") -// } -// st.Binder(s) -// } else if st.Binder != nil { -// panic("BUG: there's no filter but there's a binder") -// } -// return ofs -// } - -// type bindCountVisitor struct { -// numBinds int -// } - -// var _ rsql.Visitor = &bindCountVisitor{} - -// func (b *bindCountVisitor) Visit(node rsql.Node) (w rsql.Visitor, err error) { -// bExpr, ok := node.(*rsql.BindExpr) -// if !ok { -// return b, nil -// } -// if bExpr.Name != "?" { -// return nil, fmt.Errorf("bad bind %s: only ? 
binds are supported", bExpr.Name) -// } -// b.numBinds++ -// return nil, nil -// } - -// func (b *bindCountVisitor) VisitEnd(node rsql.Node) error { return nil } +func (sts *SyncedTableSnapshot) loadRecentCount( + db sql.Executor, + since int64, +) (int, error) { + if sts.TimestampColumn == "" { + return 0, fmt.Errorf("no timestamp column") + } + var count int + _, err := db.Exec( + sts.genRecentCount().String(), + func(stmt *sql.Statement) { + if sts.Binder != nil { + sts.Binder(stmt) + } + nParams := stmt.BindParamCount() + stmt.BindInt64(nParams-1, sts.maxRowID) + stmt.BindInt64(nParams, since) + }, + func(stmt *sql.Statement) bool { + count = stmt.ColumnInt(0) + return true + }) + return count, err +} + +func (sts *SyncedTableSnapshot) loadRecent( + db sql.Executor, + fromID KeyBytes, + limit int, + since int64, + dec func(stmt *sql.Statement) bool, +) error { + if sts.TimestampColumn == "" { + return fmt.Errorf("no timestamp column") + } + _, err := db.Exec( + sts.genSelectRecentRowIDCutoff().String(), + func(stmt *sql.Statement) { + if sts.Binder != nil { + sts.Binder(stmt) + } + nParams := stmt.BindParamCount() + stmt.BindBytes(nParams-3, fromID) + stmt.BindInt64(nParams-2, sts.maxRowID) + stmt.BindInt64(nParams-1, since) + stmt.BindInt64(nParams, int64(limit)) + }, + dec) + return err +} diff --git a/sync2/dbsync/syncedtable_test.go b/sync2/dbsync/syncedtable_test.go index 1c7a6d9eea..1285f497c7 100644 --- a/sync2/dbsync/syncedtable_test.go +++ b/sync2/dbsync/syncedtable_test.go @@ -10,24 +10,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) -// func TestRmme(t *testing.T) { -// s, err := rsql.ParseExprString("x.id = ? 
and y.id = ?") -// require.NoError(t, err) -// var v bindCountVisitor -// require.NoError(t, rsql.Walk(&v, s)) -// require.Equal(t, 2, v.numBinds) - -// st, err := rsql.NewParser(strings.NewReader("select max(rowidx) from foo")).ParseStatement() -// require.NoError(t, err) -// spew.Config.DisableMethods = true -// spew.Config.DisablePointerAddresses = true -// defer func() { -// spew.Config.DisableMethods = false -// spew.Config.DisablePointerAddresses = false -// }() -// t.Logf("s: %s\n", spew.Sdump(st)) -// } - func parseSQLExpr(t *testing.T, s string) rsql.Expr { expr, err := rsql.ParseExprString(s) require.NoError(t, err) @@ -36,44 +18,54 @@ func parseSQLExpr(t *testing.T, s string) rsql.Expr { func TestSyncedTable_GenSQL(t *testing.T) { for _, tc := range []struct { - name string - st SyncedTable - selectAllRC string - selectMaxRowID string - selectIDs string - selectIDsRC string + name string + st SyncedTable + allRC string + maxRowID string + IDs string + IDsRC string + Recent string }{ { name: "no filter", st: SyncedTable{ - TableName: "atxs", - IDColumn: "id", + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", }, - selectAllRC: `SELECT "id" FROM "atxs" WHERE "rowid" <= ?`, - selectMaxRowID: `SELECT max("rowid") FROM "atxs"`, - selectIDs: `SELECT "id" FROM "atxs" WHERE "id" >= ? ORDER BY "id" LIMIT ?`, - selectIDsRC: `SELECT "id" FROM "atxs" WHERE "id" >= ? AND "rowid" <= ? ` + + allRC: `SELECT "id" FROM "atxs" WHERE "rowid" <= ?`, + maxRowID: `SELECT max("rowid") FROM "atxs"`, + IDs: `SELECT "id" FROM "atxs" WHERE "id" >= ? ORDER BY "id" LIMIT ?`, + IDsRC: `SELECT "id" FROM "atxs" WHERE "id" >= ? AND "rowid" <= ? ` + `ORDER BY "id" LIMIT ?`, + Recent: `SELECT "id" FROM "atxs" WHERE "id" >= ? AND "rowid" <= ? ` + + `AND "received" >= ? 
ORDER BY "id" LIMIT ?`, }, { name: "filter", st: SyncedTable{ - TableName: "atxs", - IDColumn: "id", - Filter: parseSQLExpr(t, "epoch = ?"), + TableName: "atxs", + IDColumn: "id", + Filter: parseSQLExpr(t, "epoch = ?"), + TimestampColumn: "received", }, - selectAllRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "rowid" <= ?`, - selectMaxRowID: `SELECT max("rowid") FROM "atxs"`, - selectIDs: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + allRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "rowid" <= ?`, + maxRowID: `SELECT max("rowid") FROM "atxs"`, + IDs: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + `ORDER BY "id" LIMIT ?`, - selectIDsRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + IDsRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + `AND "rowid" <= ? ORDER BY "id" LIMIT ?`, + Recent: `SELECT "id" FROM "atxs" WHERE "epoch" = ? AND "id" >= ? ` + + `AND "rowid" <= ? AND "received" >= ? ORDER BY "id" LIMIT ?`, }, } { - require.Equal(t, tc.selectAllRC, tc.st.genSelectAllRowIDCutoff().String()) - require.Equal(t, tc.selectMaxRowID, tc.st.genSelectMaxRowID().String()) - require.Equal(t, tc.selectIDs, tc.st.genSelectIDRange().String()) - require.Equal(t, tc.selectIDsRC, tc.st.genSelectIDRangeWithRowIDCutoff().String()) + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.allRC, tc.st.genSelectAllRowIDCutoff().String()) + require.Equal(t, tc.maxRowID, tc.st.genSelectMaxRowID().String()) + require.Equal(t, tc.IDs, tc.st.genSelectIDRange().String()) + require.Equal(t, tc.IDsRC, tc.st.genSelectIDRangeWithRowIDCutoff().String()) + require.Equal(t, tc.Recent, tc.st.genSelectRecentRowIDCutoff().String()) + }) } } @@ -82,23 +74,25 @@ func TestSyncedTable_LoadIDs(t *testing.T) { type row struct { id string epoch int + ts int } rows := []row{ - {"0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", 1}, - {"0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", 1}, - 
{"18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", 2}, - {"1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", 2}, - {"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", 2}, - {"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", 3}, - {"24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", 3}, + {"0451cd036aff0367b07590032da827b516b63a4c1b36ea9a253dcf9a7e084980", 1, 100}, + {"0e75d10a8e98a4307dd9d0427dc1d2ebf9e45b602d159ef62c5da95197159844", 1, 110}, + {"18040e78f834b879a9585fba90f6f5e7394dc3bb27f20829baf6bfc9e1bfe44b", 2, 120}, + {"1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", 2, 150}, + {"1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", 2, 180}, + {"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", 3, 190}, + {"24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", 3, 220}, } insertRows := func(rows []row) { for _, r := range rows { - _, err := db.Exec("insert into atxs (id, epoch) values (?, ?)", + _, err := db.Exec("insert into atxs (id, epoch, received) values (?, ?, ?)", func(stmt *sql.Statement) { stmt.BindBytes(1, util.FromHex(r.id)) stmt.BindInt64(2, int64(r.epoch)) + stmt.BindInt64(3, int64(r.ts)) }, nil) require.NoError(t, err) } @@ -106,43 +100,58 @@ func TestSyncedTable_LoadIDs(t *testing.T) { initDB := func() { db = sql.InMemoryTest(t) - _, err := db.Exec("create table atxs (id char(32) not null primary key, epoch int)", nil, nil) + _, err := db.Exec(`create table atxs ( + id char(32) not null primary key, + epoch int, + received int)`, nil, nil) require.NoError(t, err) insertRows(rows) } - loadIDs := func(sts *SyncedTableSnapshot) []string { - var ids []string - require.NoError(t, sts.loadIDs(db, func(stmt *sql.Statement) bool { + mkDecode := func(ids *[]string) func(stmt *sql.Statement) bool { + return func(stmt *sql.Statement) bool { id := make(KeyBytes, stmt.ColumnLen(0)) stmt.ColumnBytes(0, id) - ids = 
append(ids, id.String()) + *ids = append(*ids, id.String()) return true - })) + } + } + + loadIDs := func(sts *SyncedTableSnapshot) []string { + var ids []string + require.NoError(t, sts.loadIDs(db, mkDecode(&ids))) return ids } loadIDsSince := func(stsNew, stsOld *SyncedTableSnapshot) []string { var ids []string - require.NoError(t, stsNew.loadIDsSince(db, stsOld, func(stmt *sql.Statement) bool { - id := make(KeyBytes, stmt.ColumnLen(0)) - stmt.ColumnBytes(0, id) - ids = append(ids, id.String()) - return true - })) + require.NoError(t, stsNew.loadIDsSince(db, stsOld, mkDecode(&ids))) return ids } loadIDRange := func(sts *SyncedTableSnapshot, from KeyBytes, limit int) []string { var ids []string - require.NoError(t, sts.loadIDRange( - db, from, limit, - func(stmt *sql.Statement) bool { - id := make(KeyBytes, stmt.ColumnLen(0)) - stmt.ColumnBytes(0, id) - ids = append(ids, id.String()) - return true - })) + require.NoError(t, sts.loadIDRange(db, from, limit, mkDecode(&ids))) + return ids + } + + loadRecentCount := func( + sts *SyncedTableSnapshot, + ts int64, + ) int { + count, err := sts.loadRecentCount(db, ts) + require.NoError(t, err) + return count + } + + loadRecent := func( + sts *SyncedTableSnapshot, + from KeyBytes, + limit int, + ts int64, + ) []string { + var ids []string + require.NoError(t, sts.loadRecent(db, from, limit, ts, mkDecode(&ids))) return ids } @@ -150,8 +159,9 @@ func TestSyncedTable_LoadIDs(t *testing.T) { initDB() st := &SyncedTable{ - TableName: "atxs", - IDColumn: "id", + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", } sts1, err := st.snapshot(db) @@ -184,9 +194,16 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", }, loadIDRange(sts1, fromID, 2)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + 
"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, loadRecent(sts1, fromID, 3, 180)) + require.Equal(t, 3, loadRecentCount(sts1, 180)) insertRows([]row{ - {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2}, + {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2, 300}, }) // the new row is not included in the first snapshot @@ -208,6 +225,13 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", }, loadIDRange(sts1, fromID, 100)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadRecent(sts1, fromID, 3, 180)) sts2, err := st.snapshot(db) require.NoError(t, err) @@ -222,7 +246,8 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", - }, loadIDs(sts2)) + }, + loadIDs(sts2)) require.ElementsMatch(t, []string{ "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", @@ -237,14 +262,40 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", }, loadIDsSince(sts2, sts1)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadRecent(sts2, fromID, 4, 180)) + require.ElementsMatch(t, + []string{ + 
"2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadRecent(sts2, + util.FromHex("2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b"), + 4, 180)) + require.ElementsMatch(t, + []string{ + "2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b", + "24b31a6acc8cd13b119dd5aa81a6c3803250a8a79eb32231f16b09e0971f1b23", + }, + loadRecent(sts2, + util.FromHex("2023eee75bec75da61ad7644bd43f02b9397a72cf489565cb53a4337975a290b"), + 2, 180)) }) t.Run("filter", func(t *testing.T) { initDB() st := &SyncedTable{ - TableName: "atxs", - IDColumn: "id", - Filter: parseSQLExpr(t, "epoch = ?"), + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", + Filter: parseSQLExpr(t, "epoch = ?"), Binder: func(stmt *sql.Statement) { stmt.BindInt64(1, 2) }, @@ -268,14 +319,18 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", }, loadIDRange(sts1, fromID, 100)) - require.ElementsMatch(t, []string{ "1a9b743abdabe7970041ba2006c0e8bb51a27b1dbfd1a8c70ef5e7703ddeaa55", }, loadIDRange(sts1, fromID, 1)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadRecent(sts1, fromID, 1, 180)) insertRows([]row{ - {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2}, + {"2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", 2, 300}, }) // the new row is not included in the first snapshot @@ -292,6 +347,11 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", }, loadIDRange(sts1, fromID, 100)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadRecent(sts1, fromID, 1, 180)) sts2, err := st.snapshot(db) require.NoError(t, 
err) @@ -316,5 +376,16 @@ func TestSyncedTable_LoadIDs(t *testing.T) { "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", }, loadIDsSince(sts2, sts1)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + "2664e267650ee22dee7d8c987b5cf44ba5596c78df3db5b99fb0ce79cc649d69", + }, + loadRecent(sts2, fromID, 2, 180)) + require.ElementsMatch(t, + []string{ + "1b49b5a17161995cc288523637bd63af5bed99f4f7188effb702da8a7a4beee1", + }, + loadRecent(sts2, fromID, 1, 180)) }) } diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go index cfd7e77b25..2b16c6ad8b 100644 --- a/sync2/hashsync/handler.go +++ b/sync2/hashsync/handler.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "sync/atomic" + "time" "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" @@ -114,6 +115,12 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { return nil, err } return &m, nil + case MessageTypeRecent: + var m RecentMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil default: return nil, fmt.Errorf("invalid message code %02x", b[0]) } @@ -227,6 +234,12 @@ func (c *wireConduit) SendSample(x, y Ordered, fingerprint any, count, sampleSiz return c.send(m) } +func (c *wireConduit) SendRecent(since time.Time) error { + return c.send(&RecentMessage{ + SinceTime: uint64(since.UnixNano()), + }) +} + func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, error) { c.initReqBuf = new(bytes.Buffer) defer func() { c.initReqBuf = nil }() diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go index b1eb8a2ed6..79791f71db 100644 --- a/sync2/hashsync/handler_test.go +++ b/sync2/hashsync/handler_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/jonboulle/clockwork" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" @@ 
-383,6 +384,7 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { require.NoError(t, err) proto := "itest" opts := []server.Opt{ + server.WithRequestSizeLimit(100_000_000), server.WithTimeout(10 * time.Second), server.WithLog(zaptest.NewLogger(t)), } @@ -404,21 +406,97 @@ func p2pRequesterGetter(t *testing.T) getRequesterFunc { } } -type dumbSyncTracer struct { - dumb bool +type syncTracer struct { + dumb bool + receivedItems int + sentItems int } -var _ Tracer = &dumbSyncTracer{} +var _ Tracer = &syncTracer{} -func (tr *dumbSyncTracer) OnDumbSync() { +func (tr *syncTracer) OnDumbSync() { tr.dumb = true } +func (tr *syncTracer) OnRecent(receivedItems, sentItems int) { + tr.receivedItems += receivedItems + tr.sentItems += sentItems +} + +type fakeRecentIterator struct { + items []types.Hash32 + p int +} + +func (it *fakeRecentIterator) Clone() Iterator { + return &fakeRecentIterator{items: it.items} +} + +func (it *fakeRecentIterator) Key() (Ordered, error) { + return it.items[it.p], nil +} + +func (it *fakeRecentIterator) Next() error { + it.p = (it.p + 1) % len(it.items) + return nil +} + +var _ Iterator = &fakeRecentIterator{} + +type fakeRecentSet struct { + ItemStore + timestamps map[types.Hash32]time.Time + clock clockwork.Clock +} + +var _ ItemStore = &fakeRecentSet{} + +var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) + +func (frs *fakeRecentSet) registerAll(ctx context.Context) error { + frs.timestamps = make(map[types.Hash32]time.Time) + t := startDate + for v, err := range IterItems[types.Hash32](ctx, frs.ItemStore) { + if err != nil { + return err + } + frs.timestamps[v] = t + t = t.Add(time.Second) + } + return nil +} + +func (frs *fakeRecentSet) Add(ctx context.Context, k Ordered) error { + if err := frs.ItemStore.Add(ctx, k); err != nil { + return err + } + h := k.(types.Hash32) + frs.timestamps[h] = frs.clock.Now() + return nil +} + +func (frs *fakeRecentSet) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { + 
var items []types.Hash32 + for h, err := range IterItems[types.Hash32](ctx, frs.ItemStore) { + if err != nil { + return nil, 0, err + } + if !frs.timestamps[h].Before(since) { + items = append(items, h) + } + } + return &fakeRecentIterator{items: items}, len(items), nil +} + func testWireSync(t *testing.T, getRequester getRequesterFunc) { for _, tc := range []struct { - name string - cfg xorSyncTestConfig - dumb bool + name string + cfg xorSyncTestConfig + dumb bool + opts []RangeSetReconcilerOption + advance time.Duration + sentRecent bool + receivedRecent bool }{ { name: "non-dumb sync", @@ -444,6 +522,25 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) { }, dumb: true, }, + { + name: "recent sync", + cfg: xorSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 1000, + minNumSpecificA: 400, + maxNumSpecificA: 500, + minNumSpecificB: 400, + maxNumSpecificB: 500, + allowReAdd: true, + }, + dumb: false, + opts: []RangeSetReconcilerOption{ + WithRecentTimeSpan(990 * time.Second), + }, + advance: 1000 * time.Second, + sentRecent: true, + receivedRecent: true, + }, { name: "larger sync", cfg: xorSyncTestConfig{ @@ -462,6 +559,7 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) { minNumSpecificB: 4, maxNumSpecificB: 100, }, + dumb: false, }, } { t.Run(tc.name, func(t *testing.T) { @@ -470,11 +568,24 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) { numSpecific int, opts []RangeSetReconcilerOption, ) bool { - var tr dumbSyncTracer - opts = append(opts, - WithTracer(&tr)) + clock := clockwork.NewFakeClockAt(startDate) + // Note that at this point, the items are already added to the sets + // and thus fakeRecentSet.Add is not invoked for them, just underlying + // set's Add method + frsA := &fakeRecentSet{ItemStore: storeA, clock: clock} + require.NoError(t, frsA.registerAll(context.Background())) + storeA = frsA + frsB := &fakeRecentSet{ItemStore: storeB, clock: clock} + require.NoError(t, 
frsB.registerAll(context.Background())) + storeB = frsB + var tr syncTracer + opts = append(opts, WithTracer(&tr), WithRangeReconcilerClock(clock)) + opts = append(opts, tc.opts...) + opts = opts[0:len(opts):len(opts)] + clock.Advance(tc.advance) withClientServer( - storeA, getRequester, opts, + storeA, getRequester, + opts, func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { nr := RmmeNumRead() nw := RmmeNumWritten() @@ -490,6 +601,8 @@ func testWireSync(t *testing.T, getRequester getRequesterFunc) { RmmeNumRead()-nr, RmmeNumWritten()-nw) }) require.Equal(t, tc.dumb, tr.dumb, "dumb sync") + require.Equal(t, tc.receivedRecent, tr.receivedItems > 0) + require.Equal(t, tc.sentRecent, tr.sentItems > 0) return true }) }) diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go index d7445177f9..acde320a84 100644 --- a/sync2/hashsync/interface.go +++ b/sync2/hashsync/interface.go @@ -3,6 +3,7 @@ package hashsync import ( "context" "io" + "time" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" @@ -65,6 +66,10 @@ type ItemStore interface { Copy() ItemStore // Has returns true if the specified key is present in ItemStore Has(ctx context.Context, k Ordered) (bool, error) + // Recent returns an Iterator that yields the items added since the specified + // timestamp. Some ItemStore implementations may not have Recent implemented, in + // which case it should return an error. 
+ Recent(ctx context.Context, since time.Time) (Iterator, int, error) } type Requester interface { diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go index b8a011e764..6902d22b30 100644 --- a/sync2/hashsync/mocks_test.go +++ b/sync2/hashsync/mocks_test.go @@ -13,6 +13,7 @@ import ( context "context" io "io" reflect "reflect" + time "time" types "github.com/spacemeshos/go-spacemesh/common/types" p2p "github.com/spacemeshos/go-spacemesh/p2p" @@ -374,6 +375,46 @@ func (c *MockItemStoreMinCall) DoAndReturn(f func(context.Context) (Iterator, er return c } +// Recent mocks base method. +func (m *MockItemStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recent", ctx, since) + ret0, _ := ret[0].(Iterator) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Recent indicates an expected call of Recent. +func (mr *MockItemStoreMockRecorder) Recent(ctx, since any) *MockItemStoreRecentCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recent", reflect.TypeOf((*MockItemStore)(nil).Recent), ctx, since) + return &MockItemStoreRecentCall{Call: call} +} + +// MockItemStoreRecentCall wrap *gomock.Call +type MockItemStoreRecentCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockItemStoreRecentCall) Return(arg0 Iterator, arg1 int, arg2 error) *MockItemStoreRecentCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockItemStoreRecentCall) Do(f func(context.Context, time.Time) (Iterator, int, error)) *MockItemStoreRecentCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockItemStoreRecentCall) DoAndReturn(f func(context.Context, time.Time) (Iterator, int, error)) *MockItemStoreRecentCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // SplitRange mocks base method. 
func (m *MockItemStore) SplitRange(ctx context.Context, preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { m.ctrl.T.Helper() diff --git a/sync2/hashsync/rangesync.go b/sync2/hashsync/rangesync.go index a5fba67f14..070984d061 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/hashsync/rangesync.go @@ -4,13 +4,200 @@ import ( "context" "errors" "fmt" + "iter" "reflect" "slices" "strings" + "time" + "github.com/jonboulle/clockwork" "go.uber.org/zap" ) +// Interactions: +// +// A: empty set; B: empty set +// A -> B: +// +// EmptySet +// EndRound +// +// B -> A: +// +// Done +// +// A: empty set; B: non-empty set +// A -> B: +// +// EmptySet +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... +// +// A -> B: +// +// Done +// +// A: small set (< maxSendRange); B: non-empty set +// A -> B: +// +// ItemBatch +// ItemBatch +// ... +// RangeContents [x, y) +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... +// +// A -> B: +// +// Done +// +// A: large set; B: non-empty set; maxDiff < 0 +// A -> B: +// +// Fingerprint [x, y) +// EndRound +// +// B -> A: +// +// Fingerprint [x, m) +// Fingerprint [m, y) +// EndRound +// +// A -> B: +// +// ItemBatch +// ItemBatch +// ... +// RangeContents [x, m) +// EndRound +// +// A -> B: +// +// Done +// +// A: large set; B: non-empty set; maxDiff >= 0; differenceMetric <= maxDiff +// NOTE: Sample includes fingerprint +// A -> B: +// +// Sample [x, y) +// EndRound +// +// B -> A: +// +// Fingerprint [x, m) +// Fingerprint [m, y) +// EndRound +// +// A -> B: +// +// ItemBatch +// ItemBatch +// ... +// RangeContents [x, m) +// EndRound +// +// A -> B: +// +// Done +// +// A: large set; B: non-empty set; maxDiff >= 0; differenceMetric > maxDiff +// A -> B: +// +// Sample [x, y) +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... 
+// RangeContents [x, y) +// EndRound +// +// A -> B: +// +// Done +// +// A: large set; B: non-empty set; sync priming; maxDiff >= 0; differenceMetric <= maxDiff (after priming) +// A -> B: +// +// ItemBatch +// ItemBatch +// ... +// Recent +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... +// Sample [x, y) +// EndRound +// +// A -> B: +// +// Fingerprint [x, m) +// Fingerprint [m, y) +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... +// RangeContents [x, m) +// EndRound +// +// A -> B: +// +// Done +// +// A: large set; B: non-empty set; sync priming; maxDiff < 0 +// A -> B: +// +// ItemBatch +// ItemBatch +// ... +// Recent +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... +// Fingerprint [x, y) +// EndRound +// +// A -> B: +// +// Fingerprint [x, m) +// Fingerprint [m, y) +// EndRound +// +// B -> A: +// +// ItemBatch +// ItemBatch +// ... +// RangeContents [x, m) +// EndRound +// +// A -> B: +// +// Done + const ( DefaultMaxSendRange = 16 DefaultItemChunkSize = 16 @@ -30,6 +217,7 @@ const ( MessageTypeItemBatch MessageTypeProbe MessageTypeSample + MessageTypeRecent ) var messageTypes = []string{ @@ -42,6 +230,7 @@ var messageTypes = []string{ "itemBatch", "probe", "sample", + "recent", } func (mtype MessageType) String() string { @@ -58,6 +247,7 @@ type SyncMessage interface { Fingerprint() any Count() int Keys() []Ordered + Since() time.Time } func formatID(v any) string { @@ -134,6 +324,8 @@ type Conduit interface { // SendSample sends a set sample. 
If 'it' is not nil, the corresponding items are // included in the sample SendSample(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error + // SendRecent sends recent items + SendRecent(since time.Time) error // ShortenKey shortens the key for minhash calculation ShortenKey(k Ordered) Ordered } @@ -169,23 +361,33 @@ func WithMaxDiff(d float64) RangeSetReconcilerOption { } // TODO: RangeSetReconciler should sit in a separate package -// and WithRangeSyncLogger should be named WithLogger -func WithRangeSyncLogger(log *zap.Logger) RangeSetReconcilerOption { +// and WithRangeReconcilerLogger should be named WithLogger +func WithRangeReconcilerLogger(log *zap.Logger) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.log = log } } +// WithRecentTimeSpan specifies the time span for recent items +func WithRecentTimeSpan(d time.Duration) RangeSetReconcilerOption { + return func(r *RangeSetReconciler) { + r.recentTimeSpan = d + } +} + // Tracer tracks the reconciliation process type Tracer interface { // OnDumbSync is called when the difference metric exceeds maxDiff and dumb // reconciliation process is used OnDumbSync() + // OnRecent is invoked when Recent message is received + OnRecent(receivedItems, sentItems int) } type nullTracer struct{} -func (t nullTracer) OnDumbSync() {} +func (t nullTracer) OnDumbSync() {} +func (t nullTracer) OnRecent(int, int) {} // WithTracer specifies a tracer for RangeSetReconciler func WithTracer(t Tracer) RangeSetReconcilerOption { @@ -194,6 +396,13 @@ func WithTracer(t Tracer) RangeSetReconcilerOption { } } +// TBD: rename +func WithRangeReconcilerClock(c clockwork.Clock) RangeSetReconcilerOption { + return func(r *RangeSetReconciler) { + r.clock = c + } +} + type ProbeResult struct { FP any Count int @@ -201,13 +410,15 @@ type ProbeResult struct { } type RangeSetReconciler struct { - is ItemStore - maxSendRange int - itemChunkSize int - sampleSize int - maxDiff float64 - tracer Tracer - log *zap.Logger 
+ is ItemStore + maxSendRange int + itemChunkSize int + sampleSize int + maxDiff float64 + recentTimeSpan time.Duration + tracer Tracer + clock clockwork.Clock + log *zap.Logger } func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *RangeSetReconciler { @@ -218,6 +429,7 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang sampleSize: DefaultSampleSize, maxDiff: -1, tracer: nullTracer{}, + clock: clockwork.NewRealClock(), log: zap.NewNop(), } for _, opt := range opts { @@ -255,10 +467,10 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info Ran // // so we can't use SendItemsOnly(), instead we use SendItems, // // which includes our items and asks the peer to send any // // items it has in the range. - // if err := c.SendRangeContents(x, y, info.Count); err != nil { + // if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { // return nil, err // } - // if err := c.SendItems(info.Count, rsr.itemChunkSize, info.Start); err != nil { + // if err := c.SendRangeContents(x, y, info.Count); err != nil { // return nil, err // } case info.Count == 0: @@ -281,27 +493,87 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info Ran return nil } +func (rsr *RangeSetReconciler) splitRange( + ctx context.Context, + c Conduit, + preceding Iterator, + count int, + x, y Ordered, +) error { + count = count / 2 + rsr.log.Debug("handleMessage: PRE split range", + HexField("x", x), HexField("y", y), + zap.Int("countArg", count)) + si, err := rsr.is.SplitRange(ctx, preceding, x, y, count) + if err != nil { + return err + } + rsr.log.Debug("handleMessage: split range", + HexField("x", x), HexField("y", y), + zap.Int("countArg", count), + zap.Int("count0", si.Parts[0].Count), + HexField("fp0", si.Parts[0].Fingerprint), + IteratorField("start0", si.Parts[0].Start), + IteratorField("end0", si.Parts[0].End), + zap.Int("count1", si.Parts[1].Count), + HexField("fp1", 
si.Parts[1].Fingerprint), + IteratorField("start1", si.Parts[1].End), + IteratorField("end1", si.Parts[1].End)) + if err := rsr.processSubrange(c, x, si.Middle, si.Parts[0]); err != nil { + return err + } + // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) + if err := rsr.processSubrange(c, si.Middle, y, si.Parts[1]); err != nil { + return err + } + return nil +} + +func (rsr *RangeSetReconciler) sendSmallRange( + c Conduit, + count int, + it Iterator, + x, y Ordered, +) error { + if count == 0 { + rsr.log.Debug("handleMessage: empty incoming range", + HexField("x", x), HexField("y", y)) + // fmt.Fprintf(os.Stderr, "small incoming range: %s -> empty range msg\n", msg) + return c.SendEmptyRange(x, y) + } + rsr.log.Debug("handleMessage: send small range", + HexField("x", x), HexField("y", y), + zap.Int("count", count), + zap.Int("maxSendRange", rsr.maxSendRange)) + // fmt.Fprintf(os.Stderr, "small incoming range: %s -> SendItems\n", msg) + if err := c.SendRangeContents(x, y, count); err != nil { + return err + } + _, err := rsr.sendItems(c, count, it, nil) + return err +} + func (rsr *RangeSetReconciler) sendItems( c Conduit, - count, itemChunkSize int, + count int, it Iterator, skipKeys []Ordered, -) error { +) (int, error) { + nSent := 0 skipPos := 0 - for i := 0; i < count; i += itemChunkSize { + for i := 0; i < count; i += rsr.itemChunkSize { // TBD: do not use chunks, just stream the contentkeys var keys []Ordered - n := min(itemChunkSize, count-i) + n := min(rsr.itemChunkSize, count-i) IN_CHUNK: for n > 0 { k, err := it.Key() if err != nil { - return err + return nSent, err } for skipPos < len(skipKeys) { cmp := k.Compare(skipKeys[skipPos]) if cmp == 0 { - // fmt.Fprintf(os.Stderr, "QQQQQ: skip key %s\n", k.(fmt.Stringer).String()) // we can skip this item. 
Advance skipPos as there are no duplicates skipPos++ continue IN_CHUNK @@ -316,31 +588,35 @@ func (rsr *RangeSetReconciler) sendItems( } keys = append(keys, k) if err := it.Next(); err != nil { - return err + return nSent, err } n-- } if err := c.SendChunk(keys); err != nil { - return err + return nSent, err } + nSent += len(keys) } - return nil + return nSent, nil } +// handleMessage handles incoming messages. Note that the set reconciliation protocol is +// designed to be stateless. func (rsr *RangeSetReconciler) handleMessage( ctx context.Context, c Conduit, preceding Iterator, - msgs []SyncMessage, - msgPos int, + msg SyncMessage, + receivedKeys []Ordered, ) (done bool, err error) { - msg := msgs[msgPos] rsr.log.Debug("handleMessage", IteratorField("preceding", preceding), zap.String("msg", SyncMessageToString(msg))) x := msg.X() y := msg.Y() done = true - if msg.Type() == MessageTypeEmptySet || (msg.Type() == MessageTypeProbe && x == nil && y == nil) { + if msg.Type() == MessageTypeEmptySet || + msg.Type() == MessageTypeRecent || + (msg.Type() == MessageTypeProbe && x == nil && y == nil) { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) it, err := rsr.is.Min(ctx) @@ -362,6 +638,9 @@ func (rsr *RangeSetReconciler) handleMessage( return false, err } } + if msg.Type() == MessageTypeRecent { + rsr.tracer.OnRecent(len(receivedKeys), 0) + } return true, nil } x, err = it.Key() @@ -370,7 +649,7 @@ func (rsr *RangeSetReconciler) handleMessage( } y = x } else if x == nil || y == nil { - return false, errors.New("bad X or Y") + return false, fmt.Errorf("bad X or Y in a message of type %s", msg.Type()) } info, err := rsr.is.GetRangeInfo(ctx, preceding, x, y, -1) if err != nil { @@ -384,6 +663,7 @@ func (rsr *RangeSetReconciler) handleMessage( HexField("fingerprint", info.Fingerprint)) // fmt.Fprintf(os.Stderr, "QQQQQ msg %s %#v fp %v start %#v end %#v count %d\n", msg.Type(), msg, info.Fingerprint, info.Start, info.End, info.Count) + // TODO: 
do not use done variable switch { case msg.Type() == MessageTypeEmptyRange || msg.Type() == MessageTypeRangeContents || @@ -397,16 +677,7 @@ func (rsr *RangeSetReconciler) handleMessage( done = false rsr.log.Debug("handleMessage: send items", zap.Int("count", info.Count), IteratorField("start", info.Start)) - var skipKeys []Ordered - if msg.Type() == MessageTypeRangeContents { - for _, m := range msgs[msgPos+1:] { - if m.Type() != MessageTypeItemBatch { - break - } - skipKeys = append(skipKeys, m.Keys()...) - } - } - if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, skipKeys); err != nil { + if _, err := rsr.sendItems(c, info.Count, info.Start, receivedKeys); err != nil { return false, err } } else { @@ -432,6 +703,37 @@ func (rsr *RangeSetReconciler) handleMessage( return false, err } return true, nil + case msg.Type() == MessageTypeRecent: + it, count, err := rsr.is.Recent(ctx, msg.Since()) + if err != nil { + return false, fmt.Errorf("error getting recent items: %w", err) + } + nSent := 0 + if count != 0 { + // Do not send back recent items that were received + if nSent, err = rsr.sendItems(c, count, it, receivedKeys); err != nil { + return false, err + } + } + rsr.log.Debug("handled recent message", + zap.Int("receivedCount", len(receivedKeys)), + zap.Int("sentCount", nSent)) + rsr.tracer.OnRecent(len(receivedKeys), nSent) + // if x == nil { + // // FIXME: code duplication + // it, err := rsr.is.Min(ctx) + // if err != nil { + // return false, err + // } + // if it != nil { + // x, err = it.Key() + // if err != nil { + // return false, err + // } + // y = x + // } + // } + return false, rsr.initiateBounded(ctx, c, x, y, false) case msg.Type() != MessageTypeFingerprint && msg.Type() != MessageTypeSample: return false, fmt.Errorf("unexpected message type %s", msg.Type()) case fingerprintEqual(info.Fingerprint, msg.Fingerprint()): @@ -444,84 +746,29 @@ func (rsr *RangeSetReconciler) handleMessage( return false, err } if 1-pr.Sim > rsr.maxDiff { 
- done = false rsr.tracer.OnDumbSync() rsr.log.Debug("handleMessage: maxDiff exceeded, sending full range", zap.Float64("sim", pr.Sim), zap.Float64("diff", 1-pr.Sim), zap.Float64("maxDiff", rsr.maxDiff)) - if err := c.SendRangeContents(x, y, info.Count); err != nil { + if _, err := rsr.sendItems(c, info.Count, info.Start, nil); err != nil { return false, err } - if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, nil); err != nil { - return false, err - } - break + return false, c.SendRangeContents(x, y, info.Count) } rsr.log.Debug("handleMessage: acceptable maxDiff, proceeding with sync", zap.Float64("sim", pr.Sim), zap.Float64("diff", 1-pr.Sim), zap.Float64("maxDiff", rsr.maxDiff)) - fallthrough + if info.Count > rsr.maxSendRange { + return false, rsr.splitRange(ctx, c, preceding, info.Count, x, y) + } + return false, rsr.sendSmallRange(c, info.Count, info.Start, x, y) // case (info.Count+1)/2 <= rsr.maxSendRange: case info.Count <= rsr.maxSendRange: - // The range differs from the peer's version of it, but the it - // is small enough (or would be small enough after split) or - // empty on our side - done = false - if info.Count != 0 { - rsr.log.Debug("handleMessage: send small range", - HexField("x", x), HexField("y", y), zap.Int("count", info.Count)) - // fmt.Fprintf(os.Stderr, "small incoming range: %s -> SendItems\n", msg) - if err := c.SendRangeContents(x, y, info.Count); err != nil { - return false, err - } - if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, nil); err != nil { - return false, err - } - } else { - rsr.log.Debug("handleMessage: empty incoming range", - HexField("x", x), HexField("y", y)) - // fmt.Fprintf(os.Stderr, "small incoming range: %s -> empty range msg\n", msg) - if err := c.SendEmptyRange(x, y); err != nil { - return false, err - } - } + return false, rsr.sendSmallRange(c, info.Count, info.Start, x, y) default: - // Need to split the range. 
- // Note that there's no special handling for rollover ranges with x >= y - // These need to be handled by ItemStore.GetRangeInfo() - // TODO: instead of count-based split, use split between X and Y with - // lower bits set to zero to avoid SQL queries on the edges - count := (info.Count + 1) / 2 - rsr.log.Debug("handleMessage: PRE split range", - HexField("x", x), HexField("y", y), - zap.Int("countArg", count)) - si, err := rsr.is.SplitRange(ctx, preceding, x, y, count) - if err != nil { - return false, err - } - rsr.log.Debug("handleMessage: split range", - HexField("x", x), HexField("y", y), - zap.Int("countArg", count), - zap.Int("count0", si.Parts[0].Count), - HexField("fp0", si.Parts[0].Fingerprint), - IteratorField("start0", si.Parts[0].Start), - IteratorField("end0", si.Parts[0].End), - zap.Int("count1", si.Parts[1].Count), - HexField("fp1", si.Parts[1].Fingerprint), - IteratorField("start1", si.Parts[1].End), - IteratorField("end1", si.Parts[1].End)) - if err := rsr.processSubrange(c, x, si.Middle, si.Parts[0]); err != nil { - return false, err - } - // fmt.Fprintf(os.Stderr, "QQQQQ: next=%q\n", qqqqRmmeK(next)) - if err := rsr.processSubrange(c, si.Middle, y, si.Parts[1]); err != nil { - return false, err - } - // fmt.Fprintf(os.Stderr, "normal: split X %s - middle %s - Y %s:\n %s", - // msg.X(), middle, msg.Y(), msg) - done = false + return false, rsr.splitRange(ctx, c, preceding, info.Count, x, y) } return done, nil } @@ -542,49 +789,66 @@ func (rsr *RangeSetReconciler) Initiate(ctx context.Context, c Conduit) error { } func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x, y Ordered) error { - // QQQQQ: TBD: add a possibility to send a sample for probing. 
- // When difference is too high, the remote side should reply with its whole [x, y) range + haveRecent := rsr.recentTimeSpan > 0 + if err := rsr.initiateBounded(ctx, c, x, y, haveRecent); err != nil { + return err + } + return c.SendEndRound() + +} +func (rsr *RangeSetReconciler) initiateBounded(ctx context.Context, c Conduit, x, y Ordered, haveRecent bool) error { rsr.log.Debug("initiate", HexField("x", x), HexField("y", y)) if x == nil { rsr.log.Debug("initiate: send empty set") - if err := c.SendEmptySet(); err != nil { + return c.SendEmptySet() + } + info, err := rsr.is.GetRangeInfo(ctx, nil, x, y, -1) + if err != nil { + return fmt.Errorf("get range info: %w", err) + } + switch { + case info.Count == 0: + panic("empty full min-min range") + case info.Count < rsr.maxSendRange: + rsr.log.Debug("initiate: send whole range", zap.Int("count", info.Count)) + if _, err := rsr.sendItems(c, info.Count, info.Start, nil); err != nil { return err } - } else { - info, err := rsr.is.GetRangeInfo(ctx, nil, x, y, -1) + return c.SendRangeContents(x, y, info.Count) + case haveRecent: + rsr.log.Debug("initiate: checking recent items") + since := rsr.clock.Now().Add(-rsr.recentTimeSpan) + it, count, err := rsr.is.Recent(ctx, since) if err != nil { - return err + return fmt.Errorf("error getting recent items: %w", err) } - switch { - case info.Count == 0: - panic("empty full min-min range") - case info.Count < rsr.maxSendRange: - rsr.log.Debug("initiate: send whole range", zap.Int("count", info.Count)) - if err := c.SendRangeContents(x, y, info.Count); err != nil { - return err - } - if err := rsr.sendItems(c, info.Count, rsr.itemChunkSize, info.Start, nil); err != nil { - return err - } - case rsr.maxDiff >= 0: - // Use minhash to check if syncing this range is feasible - rsr.log.Debug("initiate: send sample", - zap.Int("count", info.Count), - zap.Int("sampleSize", rsr.sampleSize)) - if err := c.SendSample(x, y, info.Fingerprint, info.Count, rsr.sampleSize, info.Start); err != 
nil { - return err - } - default: - rsr.log.Debug("initiate: send fingerprint", zap.Int("count", info.Count)) - if err := c.SendFingerprint(x, y, info.Fingerprint, info.Count); err != nil { + if count != 0 { + rsr.log.Debug("initiate: sending recent items", zap.Int("count", count)) + if n, err := rsr.sendItems(c, count, it, nil); err != nil { return err + } else if n != count { + panic("BUG: wrong number of items sent") } + } else { + rsr.log.Debug("initiate: no recent items") } + rsr.tracer.OnRecent(0, count) + // Send Recent message even if there are no recent items, b/c we want to + // receive recent items from the peer, if any. + if err := c.SendRecent(since); err != nil { + return err + } + return nil + case rsr.maxDiff >= 0: + // Use minhash to check if syncing this range is feasible + rsr.log.Debug("initiate: send sample", + zap.Int("count", info.Count), + zap.Int("sampleSize", rsr.sampleSize)) + return c.SendSample(x, y, info.Fingerprint, info.Count, rsr.sampleSize, info.Start) + default: + rsr.log.Debug("initiate: send fingerprint", zap.Int("count", info.Count)) + return c.SendFingerprint(x, y, info.Fingerprint, info.Count) } - if err := c.SendEndRound(); err != nil { - return err - } - return nil } func (rsr *RangeSetReconciler) getMessages(c Conduit) (msgs []SyncMessage, done bool, err error) { @@ -761,13 +1025,15 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo return done, nil } done = true - for n, msg := range msgs { + var receivedKeys []Ordered + for _, msg := range msgs { if msg.Type() == MessageTypeItemBatch { for _, k := range msg.Keys() { rsr.log.Debug("Process: add item", HexField("item", k)) if err := rsr.is.Add(ctx, k); err != nil { return false, fmt.Errorf("error adding an item to the store: %w", err) } + receivedKeys = append(receivedKeys, k) } continue } @@ -782,10 +1048,11 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo // breaks if we capture the iterator from 
handleMessage and // pass it to the next handleMessage call (it shouldn't) var msgDone bool - msgDone, err = rsr.handleMessage(ctx, c, nil, msgs, n) + msgDone, err = rsr.handleMessage(ctx, c, nil, msg, receivedKeys) if !msgDone { done = false } + receivedKeys = nil } if err != nil { @@ -810,40 +1077,64 @@ func fingerprintEqual(a, b any) bool { return reflect.DeepEqual(a, b) } -// CollectStoreItems returns the list of items in the given store -func CollectStoreItems[K Ordered](is ItemStore) ([]K, error) { - ctx := context.Background() - var r []K - it, err := is.Min(ctx) - if err != nil { - return nil, err - } - if it == nil { - return nil, nil - } - k, err := it.Key() - if err != nil { - return nil, err - } - info, err := is.GetRangeInfo(ctx, nil, k, k, -1) - if err != nil { - return nil, err - } - it, err = is.Min(ctx) - if err != nil { - return nil, err - } - for n := 0; n < info.Count; n++ { +type IterEntry[T Ordered] struct { + V T + Err error +} + +func IterItems[T Ordered](ctx context.Context, is ItemStore) iter.Seq2[T, error] { + return iter.Seq2[T, error](func(yield func(T, error) bool) { + var empty T + ctx := context.Background() + it, err := is.Min(ctx) + if err != nil { + yield(empty, err) + return + } + if it == nil { + return + } k, err := it.Key() if err != nil { - return nil, err + yield(empty, err) + return } - if k == nil { - // fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) - panic("BUG: iterator exausted before Count reached") + info, err := is.GetRangeInfo(ctx, nil, k, k, -1) + if err != nil { + yield(empty, err) + return + } + it, err = is.Min(ctx) + if err != nil { + yield(empty, err) + return + } + for n := 0; n < info.Count; n++ { + k, err := it.Key() + if err != nil { + yield(empty, err) + return + } + if k == nil { + // fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) + panic("BUG: iterator exausted before Count reached") + } + yield(k.(T), nil) + if err := it.Next(); err != nil { + yield(empty, err) + return + } + } + }) +} + +// 
CollectStoreItems returns the list of items in the given store +func CollectStoreItems[T Ordered](ctx context.Context, is ItemStore) (r []T, err error) { + for v, err := range IterItems[T](ctx, is) { + if err != nil { + return nil, err } - r = append(r, k.(K)) - it.Next() + r = append(r, v) } return r, nil } diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go index 0743088bf4..bd50bdf761 100644 --- a/sync2/hashsync/rangesync_test.go +++ b/sync2/hashsync/rangesync_test.go @@ -8,6 +8,7 @@ import ( "slices" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,6 +21,7 @@ type rangeMessage struct { fp any count int keys []Ordered + since time.Time } var _ SyncMessage = rangeMessage{} @@ -30,6 +32,7 @@ func (m rangeMessage) Y() Ordered { return m.y } func (m rangeMessage) Fingerprint() any { return m.fp } func (m rangeMessage) Count() int { return m.count } func (m rangeMessage) Keys() []Ordered { return m.keys } +func (m rangeMessage) Since() time.Time { return m.since } func (m rangeMessage) String() string { return SyncMessageToString(m) @@ -165,6 +168,14 @@ func (fc *fakeConduit) SendSample(x, y Ordered, fingerprint any, count, sampleSi return nil } +func (fc *fakeConduit) SendRecent(since time.Time) error { + fc.sendMsg(rangeMessage{ + mtype: MessageTypeRecent, + since: since, + }) + return nil +} + func (fc *fakeConduit) ShortenKey(k Ordered) Ordered { return k } @@ -359,6 +370,10 @@ func (ds *dumbStore) Has(ctx context.Context, k Ordered) (bool, error) { return false, nil } +func (ds *dumbStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { + return nil, 0, nil +} + type verifiedStoreIterator struct { t *testing.T knownGood Iterator @@ -587,6 +602,10 @@ func (vs *verifiedStore) Has(ctx context.Context, k Ordered) (bool, error) { return h2, nil } +func (vs *verifiedStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { + return nil, 
0, nil +} + type storeFactory func(t *testing.T) ItemStore func makeDumbStore(t *testing.T) ItemStore { @@ -614,7 +633,7 @@ func makeStore(t *testing.T, f storeFactory, items string) ItemStore { } func storeItemStr(is ItemStore) string { - ids, err := CollectStoreItems[sampleID](is) + ids, err := CollectStoreItems[sampleID](context.Background(), is) if err != nil { panic("store error") } diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go index b7904691dc..2c5cab0e01 100644 --- a/sync2/hashsync/sync_tree_store.go +++ b/sync2/hashsync/sync_tree_store.go @@ -3,6 +3,7 @@ package hashsync import ( "context" "errors" + "time" ) type syncTreeIterator struct { @@ -168,3 +169,7 @@ func (sts *SyncTreeStore) Has(ctx context.Context, k Ordered) (bool, error) { _, found := sts.st.Lookup(k) return found, nil } + +func (sts *SyncTreeStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { + return nil, 0, nil +} diff --git a/sync2/hashsync/wire_types.go b/sync2/hashsync/wire_types.go index 57930df388..824ec8fbb3 100644 --- a/sync2/hashsync/wire_types.go +++ b/sync2/hashsync/wire_types.go @@ -3,6 +3,7 @@ package hashsync import ( "cmp" "fmt" + "time" "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" @@ -17,6 +18,7 @@ func (*Marker) Y() Ordered { return nil } func (*Marker) Fingerprint() any { return nil } func (*Marker) Count() int { return 0 } func (*Marker) Keys() []Ordered { return nil } +func (*Marker) Since() time.Time { return time.Time{} } // DoneMessage is a SyncMessage that denotes the end of the synchronization. // The peer should stop any further processing after receiving this message. 
@@ -55,6 +57,7 @@ func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *EmptyRangeMessage) Fingerprint() any { return nil } func (m *EmptyRangeMessage) Count() int { return 0 } func (m *EmptyRangeMessage) Keys() []Ordered { return nil } +func (m *EmptyRangeMessage) Since() time.Time { return time.Time{} } // FingerprintMessage contains range fingerprint for comparison against the // peer's fingerprint of the range with the same bounds [RangeX, RangeY) @@ -72,6 +75,7 @@ func (m *FingerprintMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *FingerprintMessage) Fingerprint() any { return m.RangeFingerprint } func (m *FingerprintMessage) Count() int { return int(m.NumItems) } func (m *FingerprintMessage) Keys() []Ordered { return nil } +func (m *FingerprintMessage) Since() time.Time { return time.Time{} } // RangeContentsMessage denotes a range for which the set of items has been sent. // The peer needs to send back any items it has in the same range bounded @@ -89,6 +93,7 @@ func (m *RangeContentsMessage) Y() Ordered { return m.RangeY.ToOrdered() func (m *RangeContentsMessage) Fingerprint() any { return nil } func (m *RangeContentsMessage) Count() int { return int(m.NumItems) } func (m *RangeContentsMessage) Keys() []Ordered { return nil } +func (m *RangeContentsMessage) Since() time.Time { return time.Time{} } // ItemBatchMessage denotes a batch of items to be added to the peer's set. 
type ItemBatchMessage struct { @@ -107,6 +112,7 @@ func (m *ItemBatchMessage) Keys() []Ordered { } return r } +func (m *ItemBatchMessage) Since() time.Time { return time.Time{} } // ProbeMessage requests bounded range fingerprint and count from the peer, // along with a minhash sample if fingerprints differ @@ -124,6 +130,7 @@ func (m *ProbeMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *ProbeMessage) Fingerprint() any { return m.RangeFingerprint } func (m *ProbeMessage) Count() int { return int(m.SampleSize) } func (m *ProbeMessage) Keys() []Ordered { return nil } +func (m *ProbeMessage) Since() time.Time { return time.Time{} } // MinhashSampleItem represents an item of minhash sample subset type MinhashSampleItem uint32 @@ -174,7 +181,6 @@ func (m *SampleMessage) X() Ordered { return m.RangeX.ToOrdered() } func (m *SampleMessage) Y() Ordered { return m.RangeY.ToOrdered() } func (m *SampleMessage) Fingerprint() any { return m.RangeFingerprint } func (m *SampleMessage) Count() int { return int(m.NumItems) } - func (m *SampleMessage) Keys() []Ordered { r := make([]Ordered, len(m.Sample)) for n, item := range m.Sample { @@ -182,5 +188,22 @@ func (m *SampleMessage) Keys() []Ordered { } return r } +func (m *SampleMessage) Since() time.Time { return time.Time{} } + +// RecentMessage is a SyncMessage that denotes a set of items that have been +// added to the peer's set since the specific point in time. 
+type RecentMessage struct { + SinceTime uint64 +} + +var _ SyncMessage = &RecentMessage{} + +func (m *RecentMessage) Type() MessageType { return MessageTypeRecent } +func (m *RecentMessage) X() Ordered { return nil } +func (m *RecentMessage) Y() Ordered { return nil } +func (m *RecentMessage) Fingerprint() any { return nil } +func (m *RecentMessage) Count() int { return 0 } +func (m *RecentMessage) Keys() []Ordered { return nil } +func (m *RecentMessage) Since() time.Time { return time.Unix(0, int64(m.SinceTime)) } // TODO: don't do scalegen for empty types diff --git a/sync2/hashsync/wire_types_scale.go b/sync2/hashsync/wire_types_scale.go index 2508250d8c..3edce0cc69 100644 --- a/sync2/hashsync/wire_types_scale.go +++ b/sync2/hashsync/wire_types_scale.go @@ -401,3 +401,26 @@ func (t *SampleMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { } return total, nil } + +func (t *RecentMessage) EncodeScale(enc *scale.Encoder) (total int, err error) { + { + n, err := scale.EncodeCompact64(enc, uint64(t.SinceTime)) + if err != nil { + return total, err + } + total += n + } + return total, nil +} + +func (t *RecentMessage) DecodeScale(dec *scale.Decoder) (total int, err error) { + { + field, n, err := scale.DecodeCompact64(dec) + if err != nil { + return total, err + } + total += n + t.SinceTime = uint64(field) + } + return total, nil +} diff --git a/sync2/hashsync/xorsync_test.go b/sync2/hashsync/xorsync_test.go index 369560ce9b..94dca0e53b 100644 --- a/sync2/hashsync/xorsync_test.go +++ b/sync2/hashsync/xorsync_test.go @@ -57,12 +57,13 @@ type xorSyncTestConfig struct { maxNumSpecificA int minNumSpecificB int maxNumSpecificB int + allowReAdd bool } func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool) { opts := []RangeSetReconcilerOption{ WithMaxSendRange(cfg.maxSendRange), - WithMaxDiff(0.05), + WithMaxDiff(0.1), } numSpecificA := 
rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA numSpecificB := rand.Intn(cfg.maxNumSpecificB+1-cfg.minNumSpecificB) + cfg.minNumSpecificB @@ -76,7 +77,9 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB for _, h := range sliceA { require.NoError(t, storeA.Add(context.Background(), h)) } - storeA = &catchTransferTwice{t: t, ItemStore: storeA} + if !cfg.allowReAdd { + storeA = &catchTransferTwice{t: t, ItemStore: storeA} + } sliceB := append([]types.Hash32(nil), src[:cfg.numTestHashes-numSpecificB-numSpecificA]...) sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) @@ -84,16 +87,18 @@ func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB for _, h := range sliceB { require.NoError(t, storeB.Add(context.Background(), h)) } - storeB = &catchTransferTwice{t: t, ItemStore: storeB} + if !cfg.allowReAdd { + storeB = &catchTransferTwice{t: t, ItemStore: storeB} + } slices.SortFunc(src, func(a, b types.Hash32) int { return a.Compare(b) }) if sync(storeA, storeB, numSpecificA+numSpecificB, opts) { - itemsA, err := CollectStoreItems[types.Hash32](storeA) + itemsA, err := CollectStoreItems[types.Hash32](context.Background(), storeA) require.NoError(t, err) - itemsB, err := CollectStoreItems[types.Hash32](storeB) + itemsB, err := CollectStoreItems[types.Hash32](context.Background(), storeB) require.NoError(t, err) require.Equal(t, itemsA, itemsB) srcKeys := make([]types.Hash32, len(src)) diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index d8fb13090f..e40fa74b94 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -88,7 +88,8 @@ func TestP2P(t *testing.T) { for _, hsync := range hs { hsync.Stop() - actualItems, err := hashsync.CollectStoreItems[types.Hash32](hsync.ItemStore()) + actualItems, err := hashsync.CollectStoreItems[types.Hash32]( + context.Background(), hsync.ItemStore()) require.NoError(t, err) require.ElementsMatch(t, initialSet, actualItems) } From 
1d7f9a75fe06f8f1c1da49c2950cddff251ed11b Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 31 Aug 2024 06:21:58 +0400 Subject: [PATCH 70/76] sync2: fix logging --- sync2/hashsync/multipeer.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sync2/hashsync/multipeer.go b/sync2/hashsync/multipeer.go index 8bc31daf3e..195c94d0c5 100644 --- a/sync2/hashsync/multipeer.go +++ b/sync2/hashsync/multipeer.go @@ -10,7 +10,6 @@ import ( "golang.org/x/sync/errgroup" "github.com/spacemeshos/go-spacemesh/fetch/peers" - "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/p2p" ) @@ -161,7 +160,7 @@ func (mpr *MultiPeerReconciler) probePeers(ctx context.Context, syncPeers []p2p. mpr.logger.Debug("probe peer", zap.Stringer("peer", p)) pr, err := mpr.syncBase.Probe(ctx, p) if err != nil { - log.Warning("error probing the peer", zap.Any("peer", p), zap.Error(err)) + mpr.logger.Warn("error probing the peer", zap.Any("peer", p), zap.Error(err)) if errors.Is(err, context.Canceled) { return s, err } From c97f02381467f679135caf8c86eeab2d32961f76 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 4 Sep 2024 09:49:19 +0400 Subject: [PATCH 71/76] sync2: clean up Remove monoid tree for now and remove test dependencies. Use Go 1.23 range over function feature instead of custom iterator protocol. 
--- sync2/dbsync/combine_seqs.go | 172 +++ sync2/dbsync/combine_seqs_test.go | 220 ++++ sync2/dbsync/dbitemstore.go | 383 ------ sync2/dbsync/dbitemstore_test.go | 294 ----- sync2/dbsync/dbiter.go | 396 ------ sync2/dbsync/dbiter_test.go | 465 ------- sync2/dbsync/dbseq.go | 205 +++ sync2/dbsync/dbseq_test.go | 289 +++++ sync2/dbsync/dbset.go | 248 ++++ sync2/dbsync/dbset_test.go | 287 +++++ sync2/dbsync/fptree.go | 311 ++--- sync2/dbsync/fptree_test.go | 585 +++------ sync2/dbsync/inmemidstore.go | 72 +- sync2/dbsync/inmemidstore_test.go | 105 +- sync2/dbsync/interface.go | 14 + sync2/dbsync/p2p_test.go | 229 ++-- sync2/dbsync/sqlidstore.go | 56 +- sync2/dbsync/sqlidstore_test.go | 59 +- sync2/dbsync/syncedtable.go | 6 +- sync2/dbsync/syncedtable_test.go | 7 +- sync2/hashsync/handler.go | 374 ------ sync2/hashsync/handler_test.go | 682 ---------- sync2/hashsync/interface.go | 102 -- sync2/hashsync/mocks_test.go | 1110 ----------------- sync2/hashsync/monoid.go | 61 - sync2/hashsync/rangesync_test.go | 985 --------------- sync2/hashsync/sync_tree.go | 903 -------------- sync2/hashsync/sync_tree_store.go | 175 --- sync2/hashsync/sync_tree_test.go | 568 --------- sync2/hashsync/wire_helpers.go | 70 -- sync2/hashsync/wire_types.go | 209 ---- sync2/hashsync/xorsync.go | 59 - sync2/hashsync/xorsync_test.go | 130 -- sync2/multipeer/delim.go | 22 + sync2/multipeer/delim_test.go | 103 ++ sync2/multipeer/interface.go | 46 + sync2/multipeer/mocks_test.go | 572 +++++++++ sync2/{hashsync => multipeer}/multipeer.go | 9 +- .../{hashsync => multipeer}/multipeer_test.go | 29 +- sync2/{hashsync => multipeer}/setsyncbase.go | 58 +- .../setsyncbase_test.go | 140 ++- sync2/{hashsync => multipeer}/split_sync.go | 23 +- .../split_sync_test.go | 61 +- sync2/{hashsync => multipeer}/sync_queue.go | 20 +- .../sync_queue_test.go | 4 +- sync2/p2p.go | 53 +- sync2/p2p_test.go | 34 +- sync2/rangesync/dumbset.go | 271 ++++ sync2/rangesync/interface.go | 62 + sync2/{hashsync => rangesync}/log.go 
| 33 +- sync2/rangesync/mocks/mocks.go | 460 +++++++ sync2/rangesync/p2p.go | 107 ++ sync2/rangesync/p2p_test.go | 358 ++++++ sync2/{hashsync => rangesync}/rangesync.go | 385 +++--- sync2/rangesync/rangesync_test.go | 612 +++++++++ sync2/rangesync/wire_conduit.go | 232 ++++ sync2/rangesync/wire_conduit_test.go | 315 +++++ sync2/rangesync/wire_helpers.go | 173 +++ sync2/rangesync/wire_types.go | 213 ++++ .../wire_types_scale.go | 8 +- sync2/types/types.go | 160 +++ sync2/types/types_test.go | 101 ++ 62 files changed, 6185 insertions(+), 8310 deletions(-) create mode 100644 sync2/dbsync/combine_seqs.go create mode 100644 sync2/dbsync/combine_seqs_test.go delete mode 100644 sync2/dbsync/dbitemstore.go delete mode 100644 sync2/dbsync/dbitemstore_test.go delete mode 100644 sync2/dbsync/dbiter.go delete mode 100644 sync2/dbsync/dbiter_test.go create mode 100644 sync2/dbsync/dbseq.go create mode 100644 sync2/dbsync/dbseq_test.go create mode 100644 sync2/dbsync/dbset.go create mode 100644 sync2/dbsync/dbset_test.go create mode 100644 sync2/dbsync/interface.go delete mode 100644 sync2/hashsync/handler.go delete mode 100644 sync2/hashsync/handler_test.go delete mode 100644 sync2/hashsync/interface.go delete mode 100644 sync2/hashsync/mocks_test.go delete mode 100644 sync2/hashsync/monoid.go delete mode 100644 sync2/hashsync/rangesync_test.go delete mode 100644 sync2/hashsync/sync_tree.go delete mode 100644 sync2/hashsync/sync_tree_store.go delete mode 100644 sync2/hashsync/sync_tree_test.go delete mode 100644 sync2/hashsync/wire_helpers.go delete mode 100644 sync2/hashsync/wire_types.go delete mode 100644 sync2/hashsync/xorsync.go delete mode 100644 sync2/hashsync/xorsync_test.go create mode 100644 sync2/multipeer/delim.go create mode 100644 sync2/multipeer/delim_test.go create mode 100644 sync2/multipeer/interface.go create mode 100644 sync2/multipeer/mocks_test.go rename sync2/{hashsync => multipeer}/multipeer.go (97%) rename sync2/{hashsync => multipeer}/multipeer_test.go 
(90%) rename sync2/{hashsync => multipeer}/setsyncbase.go (63%) rename sync2/{hashsync => multipeer}/setsyncbase_test.go (58%) rename sync2/{hashsync => multipeer}/split_sync.go (90%) rename sync2/{hashsync => multipeer}/split_sync_test.go (75%) rename sync2/{hashsync => multipeer}/sync_queue.go (84%) rename sync2/{hashsync => multipeer}/sync_queue_test.go (97%) create mode 100644 sync2/rangesync/dumbset.go create mode 100644 sync2/rangesync/interface.go rename sync2/{hashsync => rangesync}/log.go (61%) create mode 100644 sync2/rangesync/mocks/mocks.go create mode 100644 sync2/rangesync/p2p.go create mode 100644 sync2/rangesync/p2p_test.go rename sync2/{hashsync => rangesync}/rangesync.go (77%) create mode 100644 sync2/rangesync/rangesync_test.go create mode 100644 sync2/rangesync/wire_conduit.go create mode 100644 sync2/rangesync/wire_conduit_test.go create mode 100644 sync2/rangesync/wire_helpers.go create mode 100644 sync2/rangesync/wire_types.go rename sync2/{hashsync => rangesync}/wire_types_scale.go (96%) create mode 100644 sync2/types/types.go create mode 100644 sync2/types/types_test.go diff --git a/sync2/dbsync/combine_seqs.go b/sync2/dbsync/combine_seqs.go new file mode 100644 index 0000000000..12428c92c1 --- /dev/null +++ b/sync2/dbsync/combine_seqs.go @@ -0,0 +1,172 @@ +package dbsync + +import ( + "iter" + "slices" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type generator struct { + nextFn func() (types.Ordered, error, bool) + stop func() + k types.Ordered + err error + done bool +} + +func gen(seq types.Seq) *generator { + var g generator + g.nextFn, g.stop = iter.Pull2(iter.Seq2[types.Ordered, error](seq)) + return &g +} + +func (g *generator) next() (k types.Ordered, err error, ok bool) { + if g.done { + return nil, nil, false + } + if g.k != nil || g.err != nil { + k = g.k + err = g.err + g.k = nil + g.err = nil + return k, err, true + } + return g.nextFn() +} + +func (g *generator) peek() (k types.Ordered, err error, ok bool) { + 
if !g.done && g.k == nil && g.err == nil { + g.k, g.err, ok = g.nextFn() + g.done = !ok + } + if g.done { + return nil, nil, false + } + return g.k, g.err, true +} + +type combinedSeq struct { + gens []*generator + wrapped []*generator +} + +// combineSeqs combines multiple ordered sequences into one, returning the smallest +// current key among all iterators at each step. +func combineSeqs(startingPoint types.Ordered, seqs ...types.Seq) types.Seq { + return func(yield func(types.Ordered, error) bool) { + var c combinedSeq + if err := c.begin(startingPoint, seqs); err != nil { + yield(nil, err) + return + } + c.iterate(yield) + } +} + +func (c *combinedSeq) begin(startingPoint types.Ordered, seqs []types.Seq) error { + for _, seq := range seqs { + g := gen(seq) + k, err, ok := g.peek() + if !ok { + continue + } + if err != nil { + return err + } + if startingPoint != nil && k.Compare(startingPoint) < 0 { + c.wrapped = append(c.wrapped, g) + } else { + c.gens = append(c.gens, g) + } + } + if len(c.gens) == 0 { + // all iterators wrapped around + c.gens = c.wrapped + c.wrapped = nil + } + return nil +} + +func (c *combinedSeq) aheadGen() (ahead *generator, aheadIdx int, err error) { + // remove any exhausted generators + j := 0 + for i := range c.gens { + _, _, ok := c.gens[i].peek() + if ok { + c.gens[j] = c.gens[i] + j++ + } + } + c.gens = c.gens[:j] + // if all the generators ha + if len(c.gens) == 0 { + if len(c.wrapped) == 0 { + return nil, 0, nil + } + c.gens = c.wrapped + c.wrapped = nil + } + ahead = c.gens[0] + aheadIdx = 0 + aK, err, _ := ahead.peek() + if err != nil { + return nil, 0, err + } + for i := 1; i < len(c.gens); i++ { + curK, err, _ := c.gens[i].peek() + if err != nil { + return nil, 0, err + } + if curK != nil { + if curK.Compare(aK) < 0 { + ahead = c.gens[i] + aheadIdx = i + aK = curK + } + } + } + return ahead, aheadIdx, nil +} + +func (c *combinedSeq) iterate(yield func(types.Ordered, error) bool) { + for { + g, idx, err := c.aheadGen() + if 
err != nil { + yield(nil, err) + return + } + if g == nil { + break + } + k, err, ok := g.next() + if err != nil { + yield(nil, err) + return + } + if !ok { + c.gens = slices.Delete(c.gens, idx, idx+1) + continue + } + if !yield(k, nil) { + break + } + newK, err, ok := g.peek() + if !ok { + // if this iterator is exhausted, it'll be removed by the + // next aheadGen call + continue + } + if err != nil { + yield(nil, err) + return + } + if k.Compare(newK) >= 0 { + // the iterator has wrapped around, move it to the wrapped + // list which will be used after all the iterators have + // wrapped around + c.wrapped = append(c.wrapped, g) + c.gens = slices.Delete(c.gens, idx, idx+1) + } + } +} diff --git a/sync2/dbsync/combine_seqs_test.go b/sync2/dbsync/combine_seqs_test.go new file mode 100644 index 0000000000..1e7572306f --- /dev/null +++ b/sync2/dbsync/combine_seqs_test.go @@ -0,0 +1,220 @@ +package dbsync + +import ( + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +var seqTestErr = errors.New("test error") + +type fakeSeqItem struct { + k string + err error + stop bool +} + +func mkFakeSeqItem(s string) fakeSeqItem { + switch s { + case "!": + return fakeSeqItem{err: seqTestErr} + case "$": + return fakeSeqItem{stop: true} + default: + return fakeSeqItem{k: s} + } +} + +type fakeSeq []fakeSeqItem + +func mkFakeSeq(s string) fakeSeq { + seq := make(fakeSeq, len(s)) + for n, c := range s { + seq[n] = mkFakeSeqItem(string(c)) + } + return seq +} + +func (seq fakeSeq) items(startIdx int) func(yield func(types.Ordered, error) bool) { + if startIdx > len(seq) { + panic("bad startIdx") + } + return func(yield func(types.Ordered, error) bool) { + if len(seq) == 0 { + return + } + n := startIdx + for { + if n == len(seq) { + n = 0 + } + item := seq[n] + if item.stop || !yield(types.KeyBytes(item.k), item.err) || item.err != nil { + return + } + n++ + } + } +} + +func seqToStr(t *testing.T, 
seq types.Seq) string { + var sb strings.Builder + var firstK types.Ordered + wrap := 0 + var s string + for k, err := range seq { + if wrap != 0 { + // after wraparound, make sure the sequence is repeated + require.NoError(t, err) + if k.Compare(firstK) == 0 { + // arrived to the element for the second time + return s + } + require.Equal(t, s[wrap], k.(types.KeyBytes)[0]) + wrap++ + continue + } + if err != nil { + require.Nil(t, k) + require.Equal(t, seqTestErr, err) + sb.WriteString("!") // error + return sb.String() + } + require.NotNil(t, k) + if firstK == nil { + firstK = k + } else if k.Compare(firstK) == 0 { + s = sb.String() // wraparound + wrap = 1 + continue + } + sb.Write(k.(types.KeyBytes)) + } + return sb.String() + "$" // stop +} + +func TestCombineSeqs(t *testing.T) { + for _, tc := range []struct { + // In each seq, $ means the end of sequence (lack of $ means wraparound), + // and ! means an error. + seqs []string + indices []int + result string + startingPoint string + }{ + // { + // seqs: []string{"abcd"}, + // indices: []int{0}, + // result: "abcd", + // startingPoint: "a", + // }, + // { + // seqs: []string{"abcd"}, + // indices: []int{0}, + // result: "abcd", + // startingPoint: "c", + // }, + // { + // seqs: []string{"abcd"}, + // indices: []int{2}, + // result: "cdab", + // startingPoint: "c", + // }, + // { + // seqs: []string{"abcd$"}, + // indices: []int{0}, + // result: "abcd$", + // startingPoint: "a", + // }, + // { + // seqs: []string{"abcd!"}, + // indices: []int{0}, + // result: "abcd!", + // startingPoint: "a", + // }, + // { + // seqs: []string{"abcd", "efgh"}, + // indices: []int{0, 0}, + // result: "abcdefgh", + // startingPoint: "a", + // }, + // { + // seqs: []string{"aceg", "bdfh"}, + // indices: []int{0, 0}, + // result: "abcdefgh", + // startingPoint: "a", + // }, + // { + // seqs: []string{"abcd$", "efgh$"}, + // indices: []int{0, 0}, + // result: "abcdefgh$", + // startingPoint: "a", + // }, + // { + // seqs: 
[]string{"aceg$", "bdfh$"}, + // indices: []int{0, 0}, + // result: "abcdefgh$", + // startingPoint: "a", + // }, + // { + // seqs: []string{"abcd!", "efgh!"}, + // indices: []int{0, 0}, + // result: "abcd!", + // startingPoint: "a", + // }, + // { + // seqs: []string{"aceg!", "bdfh!"}, + // indices: []int{0, 0}, + // result: "abcdefg!", + // startingPoint: "a", + // }, + // { + // // wraparound: + // // "ac"+"bdefgh" + // // abcdefgh ==> + // // defghabc + // // starting point is d. + // // Each sequence must either start after the starting point, or + // // all of its elements are considered to be below the starting + // // point. "ac" is considered to be wrapped around initially + // seqs: []string{"ac", "bdefgh"}, + // indices: []int{0, 1}, + // result: "defghabc", + // startingPoint: "d", + // }, + // { + // seqs: []string{"bc", "ae"}, + // indices: []int{0, 1}, + // result: "eabc", + // startingPoint: "d", + // }, + // { + // seqs: []string{"ac", "bfg", "deh"}, + // indices: []int{0, 0, 0}, + // result: "abcdefgh", + // startingPoint: "a", + // }, + { + seqs: []string{"abdefgh", "c"}, + indices: []int{0, 0}, + result: "abcdefgh", + startingPoint: "a", + }, + } { + var seqs []types.Seq + for n, s := range tc.seqs { + seqs = append(seqs, mkFakeSeq(s).items(tc.indices[n])) + } + startingPoint := types.KeyBytes(tc.startingPoint) + combined := combineSeqs(startingPoint, seqs...) 
+ for range 3 { // make sure the sequence is reusable + require.Equal(t, tc.result, seqToStr(t, combined), + "combine %v (indices %v) starting with %s", + tc.seqs, tc.indices, tc.startingPoint) + } + } +} diff --git a/sync2/dbsync/dbitemstore.go b/sync2/dbsync/dbitemstore.go deleted file mode 100644 index 9ab86985ad..0000000000 --- a/sync2/dbsync/dbitemstore.go +++ /dev/null @@ -1,383 +0,0 @@ -package dbsync - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" -) - -type DBItemStore struct { - loadMtx sync.Mutex - db sql.Database - ft *fpTree - st *SyncedTable - snapshot *SyncedTableSnapshot - dbStore *dbBackedStore - keyLen int - maxDepth int -} - -var _ hashsync.ItemStore = &DBItemStore{} - -func NewDBItemStore( - db sql.Database, - st *SyncedTable, - keyLen, maxDepth int, -) *DBItemStore { - // var np nodePool - // dbStore := newDBBackedStore(db, iterQuery, keyLen) - return &DBItemStore{ - db: db, - // ft: newFPTree(&np, dbStore, keyLen, maxDepth), - st: st, - keyLen: keyLen, - maxDepth: maxDepth, - } -} - -func (d *DBItemStore) decodeID(stmt *sql.Statement) bool { - id := make(KeyBytes, d.keyLen) // TODO: don't allocate new ID - stmt.ColumnBytes(0, id[:]) - d.ft.addStoredHash(id) - return true -} - -func (d *DBItemStore) EnsureLoaded(ctx context.Context) error { - d.loadMtx.Lock() - defer d.loadMtx.Unlock() - if d.ft != nil { - return nil - } - db := ContextSQLExec(ctx, d.db) - var err error - d.snapshot, err = d.st.snapshot(db) - if err != nil { - return fmt.Errorf("error taking snapshot: %w", err) - } - var np nodePool - d.dbStore = newDBBackedStore(db, d.snapshot, d.keyLen) - d.ft = newFPTree(&np, d.dbStore, d.keyLen, d.maxDepth) - return d.snapshot.loadIDs(db, d.decodeID) -} - -// Add implements hashsync.ItemStore. 
-func (d *DBItemStore) Add(ctx context.Context, k hashsync.Ordered) error { - if err := d.EnsureLoaded(ctx); err != nil { - return err - } - has, err := d.Has(ctx, k) // TODO: this check shouldn't be needed - if has || err != nil { - return err - } - return d.ft.addHash(k.(KeyBytes)) -} - -// GetRangeInfo implements hashsync.ItemStore. -func (d *DBItemStore) GetRangeInfo( - ctx context.Context, - preceding hashsync.Iterator, - x, y hashsync.Ordered, - count int, -) (hashsync.RangeInfo, error) { - if err := d.EnsureLoaded(ctx); err != nil { - return hashsync.RangeInfo{}, err - } - fpr, err := d.ft.fingerprintInterval(ctx, x.(KeyBytes), y.(KeyBytes), count) - if err != nil { - return hashsync.RangeInfo{}, err - } - return hashsync.RangeInfo{ - Fingerprint: fpr.fp, - Count: int(fpr.count), - Start: fpr.start, - End: fpr.end, - }, nil -} - -func (d *DBItemStore) SplitRange( - ctx context.Context, - preceding hashsync.Iterator, - x, y hashsync.Ordered, - count int, -) ( - hashsync.SplitInfo, - error, -) { - if count <= 0 { - panic("BUG: bad split count") - } - - if err := d.EnsureLoaded(ctx); err != nil { - return hashsync.SplitInfo{}, err - } - - sr, err := d.ft.easySplit(ctx, x.(KeyBytes), y.(KeyBytes), count) - if err == nil { - // fmt.Fprintf(os.Stderr, "QQQQQ: fast split, middle: %s\n", sr.middle.String()) - return hashsync.SplitInfo{ - Parts: [2]hashsync.RangeInfo{ - { - Fingerprint: sr.part0.fp, - Count: int(sr.part0.count), - Start: sr.part0.start, - End: sr.part0.end, - }, - { - Fingerprint: sr.part1.fp, - Count: int(sr.part1.count), - Start: sr.part1.start, - End: sr.part1.end, - }, - }, - Middle: sr.middle, - }, nil - } - - if !errors.Is(err, errEasySplitFailed) { - return hashsync.SplitInfo{}, err - } - - // fmt.Fprintf(os.Stderr, "QQQQQ: slow split x %s y %s\n", x.(fmt.Stringer), y.(fmt.Stringer)) - part0, err := d.GetRangeInfo(ctx, preceding, x, y, count) - if err != nil { - return hashsync.SplitInfo{}, err - } - if part0.Count == 0 { - return 
hashsync.SplitInfo{}, errors.New("can't split empty range") - } - middle, err := part0.End.Key() - if err != nil { - return hashsync.SplitInfo{}, err - } - part1, err := d.GetRangeInfo(ctx, part0.End.Clone(), middle, y, -1) - if err != nil { - return hashsync.SplitInfo{}, err - } - return hashsync.SplitInfo{ - Parts: [2]hashsync.RangeInfo{part0, part1}, - Middle: middle, - }, nil -} - -// Min implements hashsync.ItemStore. -func (d *DBItemStore) Min(ctx context.Context) (hashsync.Iterator, error) { - if err := d.EnsureLoaded(ctx); err != nil { - return nil, err - } - if d.ft.count() == 0 { - return nil, nil - } - it := d.ft.start(ctx) - if _, err := it.Key(); err != nil { - return nil, err - } - return it, nil -} - -func (d *DBItemStore) Advance(ctx context.Context) error { - d.loadMtx.Lock() - d.loadMtx.Unlock() - if d.ft == nil { - // FIXME - panic("BUG: can't advance the DBItemStore before it's loaded") - } - oldSnapshot := d.snapshot - var err error - d.snapshot, err = d.st.snapshot(ContextSQLExec(ctx, d.db)) - if err != nil { - return fmt.Errorf("error taking snapshot: %w", err) - } - d.dbStore.setSnapshot(d.snapshot) - return d.snapshot.loadIDsSince(d.db, oldSnapshot, d.decodeID) -} - -// Copy implements hashsync.ItemStore. -func (d *DBItemStore) Copy() hashsync.ItemStore { - d.loadMtx.Lock() - d.loadMtx.Unlock() - if d.ft == nil { - // FIXME - panic("BUG: can't copy the DBItemStore before it's loaded") - } - return &DBItemStore{ - db: d.db, - ft: d.ft.clone(), - st: d.st, - keyLen: d.keyLen, - maxDepth: d.maxDepth, - } -} - -// Has implements hashsync.ItemStore. 
-func (d *DBItemStore) Has(ctx context.Context, k hashsync.Ordered) (bool, error) { - if err := d.EnsureLoaded(ctx); err != nil { - return false, err - } - if d.ft.count() == 0 { - return false, nil - } - // TODO: should often be able to avoid querying the database if we check the key - // against the fptree - it := d.ft.iter(ctx, k.(KeyBytes)) - itK, err := it.Key() - if err != nil { - return false, err - } - return itK.Compare(k) == 0, nil -} - -// Recent implements hashsync.ItemStore. -func (d *DBItemStore) Recent(ctx context.Context, since time.Time) (hashsync.Iterator, int, error) { - return d.dbStore.iterSince(ctx, make(KeyBytes, d.keyLen), since.UnixNano()) -} - -// TODO: get rid of ItemStoreAdapter, it shouldn't be needed -type ItemStoreAdapter struct { - s *DBItemStore -} - -var _ hashsync.ItemStore = &ItemStoreAdapter{} - -func NewItemStoreAdapter(s *DBItemStore) *ItemStoreAdapter { - return &ItemStoreAdapter{s: s} -} - -func (a *ItemStoreAdapter) wrapIterator(it hashsync.Iterator) hashsync.Iterator { - if it == nil { - return nil - } - return &iteratorAdapter{it: it} -} - -// Add implements hashsync.ItemStore. -func (a *ItemStoreAdapter) Add(ctx context.Context, k hashsync.Ordered) error { - h := k.(types.Hash32) - return a.s.Add(ctx, KeyBytes(h[:])) -} - -// Copy implements hashsync.ItemStore. -func (a *ItemStoreAdapter) Copy() hashsync.ItemStore { - return &ItemStoreAdapter{s: a.s.Copy().(*DBItemStore)} -} - -// GetRangeInfo implements hashsync.ItemStore. 
-func (a *ItemStoreAdapter) GetRangeInfo( - ctx context.Context, - preceding hashsync.Iterator, - x hashsync.Ordered, - y hashsync.Ordered, - count int, -) (hashsync.RangeInfo, error) { - hx := x.(types.Hash32) - hy := y.(types.Hash32) - info, err := a.s.GetRangeInfo(ctx, preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) - if err != nil { - return hashsync.RangeInfo{}, err - } - var fp types.Hash12 - src := info.Fingerprint.(fingerprint) - copy(fp[:], src[:]) - return hashsync.RangeInfo{ - Fingerprint: fp, - Count: info.Count, - Start: a.wrapIterator(info.Start), - End: a.wrapIterator(info.End), - }, nil -} - -func (a *ItemStoreAdapter) SplitRange( - ctx context.Context, - preceding hashsync.Iterator, - x hashsync.Ordered, - y hashsync.Ordered, - count int, -) ( - hashsync.SplitInfo, - error, -) { - hx := x.(types.Hash32) - hy := y.(types.Hash32) - si, err := a.s.SplitRange( - ctx, preceding, KeyBytes(hx[:]), KeyBytes(hy[:]), count) - if err != nil { - return hashsync.SplitInfo{}, err - } - var fp1, fp2 types.Hash12 - src1 := si.Parts[0].Fingerprint.(fingerprint) - src2 := si.Parts[1].Fingerprint.(fingerprint) - copy(fp1[:], src1[:]) - copy(fp2[:], src2[:]) - var middle types.Hash32 - copy(middle[:], si.Middle.(KeyBytes)) - return hashsync.SplitInfo{ - Parts: [2]hashsync.RangeInfo{ - { - Fingerprint: fp1, - Count: si.Parts[0].Count, - Start: a.wrapIterator(si.Parts[0].Start), - End: a.wrapIterator(si.Parts[0].End), - }, - { - Fingerprint: fp2, - Count: si.Parts[1].Count, - Start: a.wrapIterator(si.Parts[1].Start), - End: a.wrapIterator(si.Parts[1].End), - }, - }, - Middle: middle, - }, nil -} - -// Has implements hashsync.ItemStore. -func (a *ItemStoreAdapter) Has(ctx context.Context, k hashsync.Ordered) (bool, error) { - h := k.(types.Hash32) - return a.s.Has(ctx, KeyBytes(h[:])) -} - -// Min implements hashsync.ItemStore. 
-func (a *ItemStoreAdapter) Min(ctx context.Context) (hashsync.Iterator, error) { - it, err := a.s.Min(ctx) - if err != nil { - return nil, err - } - return a.wrapIterator(it), nil -} - -// Recent implements hashsync.ItemStore. -func (d *ItemStoreAdapter) Recent(ctx context.Context, since time.Time) (hashsync.Iterator, int, error) { - it, count, err := d.s.Recent(ctx, since) - if err != nil { - return nil, 0, err - } - return d.wrapIterator(it), count, nil -} - -type iteratorAdapter struct { - it hashsync.Iterator -} - -var _ hashsync.Iterator = &iteratorAdapter{} - -func (ia *iteratorAdapter) Key() (hashsync.Ordered, error) { - k, err := ia.it.Key() - if err != nil { - return nil, err - } - var h types.Hash32 - copy(h[:], k.(KeyBytes)) - return h, nil -} - -func (ia *iteratorAdapter) Next() error { - return ia.it.Next() -} - -func (ia *iteratorAdapter) Clone() hashsync.Iterator { - return &iteratorAdapter{it: ia.it.Clone()} -} diff --git a/sync2/dbsync/dbitemstore_test.go b/sync2/dbsync/dbitemstore_test.go deleted file mode 100644 index 974ede4453..0000000000 --- a/sync2/dbsync/dbitemstore_test.go +++ /dev/null @@ -1,294 +0,0 @@ -package dbsync - -import ( - "context" - "fmt" - "testing" - - "github.com/spacemeshos/go-spacemesh/common/util" - "github.com/stretchr/testify/require" -) - -func TestDBItemStore_Empty(t *testing.T) { - db := populateDB(t, 32, nil) - st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", - } - s := NewDBItemStore(db, st, 32, 24) - ctx := context.Background() - it, err := s.Min(ctx) - require.NoError(t, err) - require.Nil(t, it) - - info, err := s.GetRangeInfo(ctx, nil, - KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), - KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), - -1) - require.NoError(t, err) - require.Equal(t, 0, info.Count) - require.Equal(t, "000000000000000000000000", info.Fingerprint.(fmt.Stringer).String()) - require.Nil(t, info.Start) 
- require.Nil(t, info.End) - - info, err = s.GetRangeInfo(ctx, nil, - KeyBytes(util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")), - KeyBytes(util.FromHex("9999000000000000000000000000000000000000000000000000000000000000")), - -1) - require.NoError(t, err) - require.Equal(t, 0, info.Count) - require.Equal(t, "000000000000000000000000", info.Fingerprint.(fmt.Stringer).String()) - require.Nil(t, info.Start) - require.Nil(t, info.End) -} - -func TestDBItemStore(t *testing.T) { - ids := []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), - } - ctx := context.Background() - db := populateDB(t, 32, ids) - st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", - } - s := NewDBItemStore(db, st, 32, 24) - it, err := s.Min(ctx) - require.NoError(t, err) - require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - itKey(t, it).String()) - has, err := s.Has(ctx, KeyBytes(util.FromHex("9876000000000000000000000000000000000000000000000000000000000000"))) - require.NoError(t, err) - require.False(t, has) - - for _, tc := range []struct { - xIdx, yIdx int - limit int - fp string - count int - startIdx, endIdx int - }{ - { - xIdx: 0, - yIdx: 0, - limit: 0, - fp: "000000000000000000000000", - count: 0, - startIdx: 0, - endIdx: 0, - }, - { - xIdx: 1, - yIdx: 1, - limit: -1, - fp: "642464b773377bbddddddddd", - count: 5, - startIdx: 1, - endIdx: 1, - }, - { - xIdx: 0, - yIdx: 3, - limit: -1, - fp: "4761032dcfe98ba555555555", - count: 3, - startIdx: 0, - endIdx: 3, - }, - { - xIdx: 2, - yIdx: 0, - limit: -1, - fp: "761032cfe98ba54ddddddddd", - count: 3, 
- startIdx: 2, - endIdx: 0, - }, - { - xIdx: 3, - yIdx: 2, - limit: 3, - fp: "2345679abcdef01888888888", - count: 3, - startIdx: 3, - endIdx: 1, - }, - } { - name := fmt.Sprintf("%d-%d_%d", tc.xIdx, tc.yIdx, tc.limit) - t.Run(name, func(t *testing.T) { - t.Logf("x %s y %s limit %d", ids[tc.xIdx], ids[tc.yIdx], tc.limit) - info, err := s.GetRangeInfo(ctx, nil, ids[tc.xIdx], ids[tc.yIdx], tc.limit) - require.NoError(t, err) - require.Equal(t, tc.count, info.Count) - require.Equal(t, tc.fp, info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[tc.startIdx], itKey(t, info.Start)) - require.Equal(t, ids[tc.endIdx], itKey(t, info.End)) - has, err := s.Has(ctx, ids[tc.startIdx]) - require.NoError(t, err) - require.True(t, has) - has, err = s.Has(ctx, ids[tc.endIdx]) - require.NoError(t, err) - require.True(t, has) - }) - } -} - -func TestDBItemStore_Add(t *testing.T) { - ids := []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - } - db := populateDB(t, 32, ids) - st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", - } - s := NewDBItemStore(db, st, 32, 24) - ctx := context.Background() - it, err := s.Min(ctx) - require.NoError(t, err) - require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - itKey(t, it).String()) - - newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) - require.NoError(t, s.Add(context.Background(), newID)) - - // // QQQQQ: rm - // s.ft.traceEnabled = true - // var sb strings.Builder - // s.ft.dump(&sb) - // t.Logf("tree:\n%s", sb.String()) - - info, err := s.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 3, info.Count) - 
require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) -} - -func TestDBItemStore_Copy(t *testing.T) { - ids := []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - } - db := populateDB(t, 32, ids) - st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", - } - s := NewDBItemStore(db, st, 32, 24) - ctx := context.Background() - it, err := s.Min(ctx) - require.NoError(t, err) - require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - itKey(t, it).String()) - - copy := s.Copy() - - info, err := copy.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 2, info.Count) - require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - newID := KeyBytes(util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000")) - require.NoError(t, copy.Add(context.Background(), newID)) - - info, err = s.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 2, info.Count) - require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - info, err = copy.GetRangeInfo(ctx, nil, ids[2], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 3, info.Count) - require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[2], itKey(t, info.Start)) - require.Equal(t, 
ids[0], itKey(t, info.End)) -} - -func TestDBItemStore_Advance(t *testing.T) { - ids := []KeyBytes{ - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - util.FromHex("123456789abcdef0000000000000000000000000000000000000000000000000"), - util.FromHex("5555555555555555555555555555555555555555555555555555555555555555"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), - } - db := populateDB(t, 32, ids) - st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", - } - s := NewDBItemStore(db, st, 32, 24) - ctx := context.Background() - require.NoError(t, s.EnsureLoaded(ctx)) - - copy := s.Copy() - - info, err := s.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 4, info.Count) - require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - info, err = copy.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 4, info.Count) - require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - insertDBItems(t, db, []KeyBytes{ - util.FromHex("abcdef1234567890000000000000000000000000000000000000000000000000"), - }) - - info, err = s.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 4, info.Count) - require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - info, err = copy.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 4, info.Count) - require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, 
ids[0], itKey(t, info.End)) - - require.NoError(t, s.Advance(ctx)) - - info, err = s.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 5, info.Count) - require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - info, err = copy.GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 4, info.Count) - require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) - - info, err = s.Copy().GetRangeInfo(ctx, nil, ids[0], ids[0], -1) - require.NoError(t, err) - require.Equal(t, 5, info.Count) - require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.(fmt.Stringer).String()) - require.Equal(t, ids[0], itKey(t, info.Start)) - require.Equal(t, ids[0], itKey(t, info.End)) -} diff --git a/sync2/dbsync/dbiter.go b/sync2/dbsync/dbiter.go deleted file mode 100644 index 8d1480c33d..0000000000 --- a/sync2/dbsync/dbiter.go +++ /dev/null @@ -1,396 +0,0 @@ -package dbsync - -import ( - "bytes" - "encoding/hex" - "errors" - "slices" - - "github.com/hashicorp/golang-lru/v2/simplelru" - "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" -) - -type KeyBytes []byte - -var _ hashsync.Ordered = KeyBytes(nil) - -func (k KeyBytes) String() string { - return hex.EncodeToString(k) -} - -func (k KeyBytes) Clone() KeyBytes { - return slices.Clone(k) -} - -func (k KeyBytes) Compare(other any) int { - return bytes.Compare(k, other.(KeyBytes)) -} - -func (k KeyBytes) inc() (overflow bool) { - for i := len(k) - 1; i >= 0; i-- { - k[i]++ - if k[i] != 0 { - return false - } - } - - return true -} - -func (k KeyBytes) zero() { - for i := range k { - k[i] = 0 - } -} - -func (k KeyBytes) isZero() bool { - for _, b := range k { - if b != 0 { - 
return false - } - } - return true -} - -var errEmptySet = errors.New("empty range") - -type dbIDKey struct { - id string - chunkSize int -} - -type lru = simplelru.LRU[dbIDKey, []KeyBytes] - -const lruCacheSize = 1024 * 1024 - -func newLRU() *lru { - cache, err := simplelru.NewLRU[dbIDKey, []KeyBytes](lruCacheSize, nil) - if err != nil { - panic("BUG: failed to create LRU cache: " + err.Error()) - } - return cache -} - -type dbRangeIterator struct { - db sql.Executor - from KeyBytes - sts *SyncedTableSnapshot - chunkSize int - ts int64 - maxChunkSize int - chunk []KeyBytes - pos int - keyLen int - singleChunk bool - loaded bool - cache *lru -} - -var _ hashsync.Iterator = &dbRangeIterator{} - -// makeDBIterator creates a dbRangeIterator and initializes it from the database. -// If query returns no rows even after starting from zero ID, errEmptySet error is returned. -func newDBRangeIterator( - db sql.Executor, - sts *SyncedTableSnapshot, - from KeyBytes, - ts int64, - maxChunkSize int, - lru *lru, -) hashsync.Iterator { - if from == nil { - panic("BUG: makeDBIterator: nil from") - } - if maxChunkSize <= 0 { - panic("BUG: makeDBIterator: chunkSize must be > 0") - } - return &dbRangeIterator{ - db: db, - from: from.Clone(), - sts: sts, - chunkSize: 1, - ts: ts, - maxChunkSize: maxChunkSize, - keyLen: len(from), - chunk: make([]KeyBytes, maxChunkSize), - singleChunk: false, - loaded: false, - cache: lru, - } -} - -func (it *dbRangeIterator) loadCached(key dbIDKey) (bool, int) { - if it.cache == nil { - return false, 0 - } - chunk, ok := it.cache.Get(key) - if !ok { - // fmt.Fprintf(os.Stderr, "QQQQQ: cache miss\n") - return false, 0 - } - - // fmt.Fprintf(os.Stderr, "QQQQQ: cache hit, chunk size %d\n", len(chunk)) - for n, id := range it.chunk[:len(chunk)] { - if id == nil { - id = make([]byte, it.keyLen) - it.chunk[n] = id - } - copy(id, chunk[n]) - } - return true, len(chunk) -} - -func (it *dbRangeIterator) load() error { - it.pos = 0 - if it.singleChunk { - // we 
have a single-chunk DB iterator, don't need to reload, - // just wrap around - return nil - } - - n := 0 - // if the chunk size was reduced due to a short chunk before wraparound, we need - // to extend it back - if cap(it.chunk) < it.chunkSize { - it.chunk = make([]KeyBytes, it.chunkSize) - } else { - it.chunk = it.chunk[:it.chunkSize] - } - // fmt.Fprintf(os.Stderr, "QQQQQ: from: %s chunkSize: %d\n", hex.EncodeToString(it.from), it.chunkSize) - key := dbIDKey{string(it.from), it.chunkSize} - var ierr, err error - found, n := it.loadCached(key) - if !found { - dec := func(stmt *sql.Statement) bool { - if n >= len(it.chunk) { - ierr = errors.New("too many rows") - return false - } - // we reuse existing slices when possible for retrieving new IDs - id := it.chunk[n] - if id == nil { - id = make([]byte, it.keyLen) - it.chunk[n] = id - } - stmt.ColumnBytes(0, id) - n++ - return true - } - if it.ts <= 0 { - err = it.sts.loadIDRange(it.db, it.from, it.chunkSize, dec) - } else { - err = it.sts.loadRecent(it.db, it.from, it.chunkSize, it.ts, dec) - } - if err == nil && ierr == nil && it.cache != nil { - cached := make([]KeyBytes, n) - for n, id := range it.chunk[:n] { - cached[n] = slices.Clone(id) - } - it.cache.Add(key, cached) - } - } - fromZero := it.from.isZero() - it.chunkSize = min(it.chunkSize*2, it.maxChunkSize) - switch { - case err != nil || ierr != nil: - return errors.Join(ierr, err) - case n == 0: - // empty chunk - if fromZero { - // already wrapped around or started from 0, - // the set is empty - return errEmptySet - } - // wrap around - it.from.zero() - return it.load() - case n < len(it.chunk): - // short chunk means there are no more items after it, - // start the next chunk from 0 - it.from.zero() - it.chunk = it.chunk[:n] - // wrapping around on an incomplete chunk that started - // from 0 means we have just a single chunk - it.singleChunk = fromZero - default: - // use last item incremented by 1 as the start of the next chunk - copy(it.from, 
it.chunk[n-1]) - // inc may wrap around if it's 0xffff...fff, but it's fine - if it.from.inc() { - // if we wrapped around and the current chunk started from 0, - // we have just a single chunk - it.singleChunk = fromZero - } - } - return nil -} - -func (it *dbRangeIterator) Key() (hashsync.Ordered, error) { - if !it.loaded { - if err := it.load(); err != nil { - return nil, err - } - it.loaded = true - } - if it.pos < len(it.chunk) { - return slices.Clone(it.chunk[it.pos]), nil - } - return nil, errEmptySet -} - -func (it *dbRangeIterator) Next() error { - if !it.loaded { - if err := it.load(); err != nil { - return err - } - it.loaded = true - if len(it.chunk) == 0 || it.pos != 0 { - panic("BUG: load didn't report empty set or set a wrong pos") - } - it.pos++ - return nil - } - it.pos++ - if it.pos < len(it.chunk) { - return nil - } - return it.load() -} - -func (it *dbRangeIterator) Clone() hashsync.Iterator { - cloned := *it - cloned.from = slices.Clone(it.from) - cloned.chunk = make([]KeyBytes, len(it.chunk)) - for i, k := range it.chunk { - cloned.chunk[i] = slices.Clone(k) - } - return &cloned -} - -type combinedIterator struct { - startingPoint hashsync.Ordered - iters []hashsync.Iterator - wrapped []hashsync.Iterator - ahead hashsync.Iterator - aheadIdx int -} - -// combineIterators combines multiple iterators into one, returning the smallest current -// key among all iterators at each step. -func combineIterators(startingPoint hashsync.Ordered, iters ...hashsync.Iterator) hashsync.Iterator { - return &combinedIterator{startingPoint: startingPoint, iters: iters} -} - -func (c *combinedIterator) begin() error { - // Some of the iterators may already be wrapped around. - // This corresponds to the case when we ask an idStore for iterator - // with a starting point beyond the last key in the store. 
- iters := c.iters - c.iters = nil - for _, it := range iters { - k, err := it.Key() - if err != nil { - if errors.Is(err, errEmptySet) { - // ignore empty iterators - continue - } - return err - } - if c.startingPoint != nil && k.Compare(c.startingPoint) < 0 { - c.wrapped = append(c.wrapped, it) - } else { - c.iters = append(c.iters, it) - } - } - if len(c.iters) == 0 { - // all iterators wrapped around - c.iters = c.wrapped - c.wrapped = nil - } - c.startingPoint = nil - return nil -} - -func (c *combinedIterator) aheadIterator() (hashsync.Iterator, error) { - if err := c.begin(); err != nil { - return nil, err - } - if c.ahead == nil { - if len(c.iters) == 0 { - if len(c.wrapped) == 0 { - return nil, nil - } - c.iters = c.wrapped - c.wrapped = nil - } - c.ahead = c.iters[0] - c.aheadIdx = 0 - for i := 1; i < len(c.iters); i++ { - curK, err := c.iters[i].Key() - if err != nil { - return nil, err - } - if curK != nil { - aK, err := c.ahead.Key() - if err != nil { - return nil, err - } - if curK.Compare(aK) < 0 { - c.ahead = c.iters[i] - c.aheadIdx = i - } - } - } - } - return c.ahead, nil -} - -func (c *combinedIterator) Key() (hashsync.Ordered, error) { - it, err := c.aheadIterator() - if err != nil { - return nil, err - } - return it.Key() -} - -func (c *combinedIterator) Next() error { - it, err := c.aheadIterator() - if err != nil { - return err - } - oldKey, err := it.Key() - if err != nil { - return err - } - if err := it.Next(); err != nil { - return err - } - c.ahead = nil - newKey, err := it.Key() - if err != nil { - return err - } - if oldKey.Compare(newKey) >= 0 { - // the iterator has wrapped around, move it to the wrapped list - // which will be used after all the iterators have wrapped around - c.wrapped = append(c.wrapped, it) - c.iters = append(c.iters[:c.aheadIdx], c.iters[c.aheadIdx+1:]...) 
- } - return nil -} - -func (c *combinedIterator) Clone() hashsync.Iterator { - cloned := &combinedIterator{ - iters: make([]hashsync.Iterator, len(c.iters)), - wrapped: make([]hashsync.Iterator, len(c.wrapped)), - startingPoint: c.startingPoint, - } - for i, it := range c.iters { - cloned.iters[i] = it.Clone() - } - for i, it := range c.wrapped { - cloned.wrapped[i] = it.Clone() - } - return cloned -} diff --git a/sync2/dbsync/dbiter_test.go b/sync2/dbsync/dbiter_test.go deleted file mode 100644 index ae526585f8..0000000000 --- a/sync2/dbsync/dbiter_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package dbsync - -import ( - "context" - "encoding/hex" - "errors" - "fmt" - "slices" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" -) - -func TestIncID(t *testing.T) { - for _, tc := range []struct { - id, expected KeyBytes - overflow bool - }{ - { - id: KeyBytes{0x00, 0x00, 0x00, 0x00}, - expected: KeyBytes{0x00, 0x00, 0x00, 0x01}, - overflow: false, - }, - { - id: KeyBytes{0x00, 0x00, 0x00, 0xff}, - expected: KeyBytes{0x00, 0x00, 0x01, 0x00}, - overflow: false, - }, - { - id: KeyBytes{0xff, 0xff, 0xff, 0xff}, - expected: KeyBytes{0x00, 0x00, 0x00, 0x00}, - overflow: true, - }, - } { - id := make(KeyBytes, len(tc.id)) - copy(id, tc.id) - require.Equal(t, tc.overflow, id.inc()) - require.Equal(t, tc.expected, id) - } -} - -func createDB(t *testing.T, keyLen int) sql.Database { - db := sql.InMemoryTest(t) - _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) - require.NoError(t, err) - return db -} - -func insertDBItems(t *testing.T, db sql.Database, content []KeyBytes) { - err := db.WithTx(context.Background(), func(tx sql.Transaction) error { - for _, id := range content { - _, err := tx.Exec( - "insert into foo(id) values(?)", - func(stmt *sql.Statement) { - stmt.BindBytes(1, id) - }, nil) - if err != nil { - return 
err - } - } - return nil - }) - require.NoError(t, err) -} - -func deleteDBItems(t *testing.T, db sql.Database) { - _, err := db.Exec("delete from foo", nil, nil) - require.NoError(t, err) -} - -func populateDB(t *testing.T, keyLen int, content []KeyBytes) sql.Database { - db := createDB(t, keyLen) - insertDBItems(t, db, content) - return db -} - -func TestDBRangeIterator(t *testing.T) { - db := createDB(t, 4) - for _, tc := range []struct { - items []KeyBytes - from KeyBytes - fromN int - expErr error - }{ - { - items: nil, - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - expErr: errEmptySet, - }, - { - items: nil, - from: KeyBytes{0x80, 0x00, 0x00, 0x00}, - expErr: errEmptySet, - }, - { - items: nil, - from: KeyBytes{0xff, 0xff, 0xff, 0xff}, - expErr: errEmptySet, - }, - { - items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x00}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x00}, - }, - from: KeyBytes{0x01, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0x00, 0x00, 0x00, 0x00}, - }, - from: KeyBytes{0xff, 0xff, 0xff, 0xff}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0x01, 0x02, 0x03, 0x04}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0x01, 0x02, 0x03, 0x04}, - }, - from: KeyBytes{0x01, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0x01, 0x02, 0x03, 0x04}, - }, - from: KeyBytes{0xff, 0xff, 0xff, 0xff}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0xff, 0xff, 0xff, 0xff}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0xff, 0xff, 0xff, 0xff}, - }, - from: KeyBytes{0x01, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - {0xff, 0xff, 0xff, 0xff}, - }, - from: KeyBytes{0xff, 0xff, 0xff, 0xff}, - fromN: 0, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: 
KeyBytes{0x00, 0x00, 0x00, 0x00}, - fromN: 0, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x01}, - fromN: 0, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x02}, - fromN: 1, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x03}, - fromN: 1, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x05}, - fromN: 2, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: KeyBytes{0x00, 0x00, 0x00, 0x07}, - fromN: 3, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - }, - from: KeyBytes{0xff, 0xff, 0xff, 0xff}, - fromN: 0, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - 4: {0x00, 0x00, 0x01, 0x00}, - 5: {0x00, 0x00, 0x03, 0x00}, - 6: {0x00, 0x01, 0x00, 0x00}, - 7: {0x00, 0x05, 0x00, 0x00}, - 8: {0x03, 0x05, 0x00, 0x00}, - 9: {0x09, 0x05, 0x00, 0x00}, - 10: {0x0a, 0x05, 0x00, 0x00}, - 11: {0xff, 0xff, 0xff, 0xff}, - }, - from: KeyBytes{0x00, 0x00, 0x03, 0x01}, - fromN: 6, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - 4: {0x00, 0x00, 0x01, 0x00}, - 5: {0x00, 0x00, 0x03, 
0x00}, - 6: {0x00, 0x01, 0x00, 0x00}, - 7: {0x00, 0x05, 0x00, 0x00}, - 8: {0x03, 0x05, 0x00, 0x00}, - 9: {0x09, 0x05, 0x00, 0x00}, - 10: {0x0a, 0x05, 0x00, 0x00}, - 11: {0xff, 0xff, 0xff, 0xff}, - }, - from: KeyBytes{0x00, 0x01, 0x00, 0x00}, - fromN: 6, - }, - { - items: []KeyBytes{ - 0: {0x00, 0x00, 0x00, 0x01}, - 1: {0x00, 0x00, 0x00, 0x03}, - 2: {0x00, 0x00, 0x00, 0x05}, - 3: {0x00, 0x00, 0x00, 0x07}, - 4: {0x00, 0x00, 0x01, 0x00}, - 5: {0x00, 0x00, 0x03, 0x00}, - 6: {0x00, 0x01, 0x00, 0x00}, - 7: {0x00, 0x05, 0x00, 0x00}, - 8: {0x03, 0x05, 0x00, 0x00}, - 9: {0x09, 0x05, 0x00, 0x00}, - 10: {0x0a, 0x05, 0x00, 0x00}, - 11: {0xff, 0xff, 0xff, 0xff}, - }, - from: KeyBytes{0xff, 0xff, 0xff, 0xff}, - fromN: 11, - }, - } { - deleteDBItems(t, db) - insertDBItems(t, db, tc.items) - cache := newLRU() - st := &SyncedTable{ - TableName: "foo", - IDColumn: "id", - } - sts, err := st.snapshot(db) - require.NoError(t, err) - for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { - it := newDBRangeIterator(db, sts, tc.from, -1, maxChunkSize, cache) - if tc.expErr != nil { - _, err := it.Key() - require.ErrorIs(t, err, tc.expErr) - continue - } - // when there are no items, errEmptySet is returned - require.NotEmpty(t, tc.items) - clonedIt := it.Clone() - var collected []KeyBytes - for i := 0; i < len(tc.items); i++ { - k := itKey(t, it) - require.NotNil(t, k) - collected = append(collected, k) - require.Equal(t, k, itKey(t, clonedIt)) - require.NoError(t, it.Next()) - // calling Next on the original iterator - // shouldn't affect the cloned one - require.Equal(t, k, itKey(t, clonedIt)) - require.NoError(t, clonedIt.Next()) - } - expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) - require.Equal(t, expected, collected, "count=%d from=%s maxChunkSize=%d", - len(tc.items), hex.EncodeToString(tc.from), maxChunkSize) - clonedIt = it.Clone() - for range 2 { - for i := 0; i < len(tc.items); i++ { - k := itKey(t, it) - require.Equal(t, collected[i], k) - 
require.Equal(t, k, itKey(t, clonedIt)) - require.NoError(t, it.Next()) - require.Equal(t, k, itKey(t, clonedIt)) - require.NoError(t, clonedIt.Next()) - } - } - } - } -} - -type fakeIterator struct { - items, allItems []KeyBytes -} - -var _ hashsync.Iterator = &fakeIterator{} - -func (it *fakeIterator) Key() (hashsync.Ordered, error) { - if len(it.allItems) == 0 { - return nil, errEmptySet - } - if len(it.items) == 0 { - it.items = it.allItems - } - return KeyBytes(it.items[0]), nil -} - -func (it *fakeIterator) Next() error { - if len(it.items) == 0 { - it.items = it.allItems - } - it.items = it.items[1:] - if len(it.items) != 0 && string(it.items[0]) == "error" { - return errors.New("iterator error") - } - return nil -} - -func (it *fakeIterator) Clone() hashsync.Iterator { - cloned := &fakeIterator{ - allItems: make([]KeyBytes, len(it.allItems)), - } - for i, k := range it.allItems { - cloned.allItems[i] = slices.Clone(k) - } - cloned.items = cloned.allItems[len(it.allItems)-len(it.items):] - return cloned -} - -func TestCombineIterators(t *testing.T) { - it1 := &fakeIterator{ - allItems: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x0a, 0x05, 0x00, 0x00}, - }, - } - it2 := &fakeIterator{ - allItems: []KeyBytes{ - {0x00, 0x00, 0x00, 0x03}, - {0xff, 0xff, 0xff, 0xff}, - }, - } - - it := combineIterators(nil, it1, it2) - clonedIt := it.Clone() - for range 3 { - var collected []KeyBytes - for range 4 { - k := itKey(t, it) - collected = append(collected, k) - require.Equal(t, k, itKey(t, clonedIt)) - require.NoError(t, it.Next()) - require.Equal(t, k, itKey(t, clonedIt)) - require.NoError(t, clonedIt.Next()) - } - require.Equal(t, []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x0a, 0x05, 0x00, 0x00}, - {0xff, 0xff, 0xff, 0xff}, - }, collected) - require.Equal(t, KeyBytes{0x00, 0x00, 0x00, 0x01}, itKey(t, it)) - } - - it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}, KeyBytes("error")}} - it2 = &fakeIterator{allItems: 
[]KeyBytes{KeyBytes{0, 0, 0, 1}}} - - it = combineIterators(nil, it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, itKey(t, it)) - require.Error(t, it.Next()) - - it1 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 0}}} - it2 = &fakeIterator{allItems: []KeyBytes{KeyBytes{0, 0, 0, 1}, KeyBytes("error")}} - - it = combineIterators(nil, it1, it2) - require.Equal(t, KeyBytes{0, 0, 0, 0}, itKey(t, it)) - require.NoError(t, it.Next()) - require.Equal(t, KeyBytes{0, 0, 0, 1}, itKey(t, it)) - require.Error(t, it.Next()) -} - -func TestCombineIteratorsInitiallyWrapped(t *testing.T) { - it1 := &fakeIterator{ - allItems: []KeyBytes{ - {0x00, 0x00, 0x00, 0x01}, - {0x0a, 0x05, 0x00, 0x00}, - }, - } - it2 := &fakeIterator{ - allItems: []KeyBytes{ - {0x00, 0x00, 0x00, 0x03}, - {0xff, 0x00, 0x00, 0x55}, - }, - } - require.NoError(t, it2.Next()) - it := combineIterators(KeyBytes{0xff, 0x00, 0x00, 0x55}, it1, it2) - var collected []KeyBytes - for range 4 { - k := itKey(t, it) - collected = append(collected, k) - require.NoError(t, it.Next()) - } - require.Equal(t, []KeyBytes{ - {0xff, 0x00, 0x00, 0x55}, - {0x00, 0x00, 0x00, 0x01}, - {0x00, 0x00, 0x00, 0x03}, - {0x0a, 0x05, 0x00, 0x00}, - }, collected) - require.Equal(t, KeyBytes{0xff, 0x00, 0x00, 0x55}, itKey(t, it)) -} - -func itKey(t *testing.T, it hashsync.Iterator) KeyBytes { - k, err := it.Key() - require.NoError(t, err) - require.NotNil(t, k) - return k.(KeyBytes) -} diff --git a/sync2/dbsync/dbseq.go b/sync2/dbsync/dbseq.go new file mode 100644 index 0000000000..f0ac140b7e --- /dev/null +++ b/sync2/dbsync/dbseq.go @@ -0,0 +1,205 @@ +package dbsync + +import ( + "errors" + "slices" + + "github.com/hashicorp/golang-lru/v2/simplelru" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type dbIDKey struct { + id string + chunkSize int +} + +type lru = simplelru.LRU[dbIDKey, []types.KeyBytes] + +const lruCacheSize = 1024 * 1024 + +func newLRU() *lru { + cache, err := 
simplelru.NewLRU[dbIDKey, []types.KeyBytes](lruCacheSize, nil) + if err != nil { + panic("BUG: failed to create LRU cache: " + err.Error()) + } + return cache +} + +type dbSeq struct { + db sql.Executor + from types.KeyBytes + sts *SyncedTableSnapshot + chunkSize int + ts int64 + maxChunkSize int + chunk []types.KeyBytes + pos int + keyLen int + singleChunk bool + cache *lru +} + +// idsFromTable iterates over the id field values in an SQLite table. +func idsFromTable( + db sql.Executor, + sts *SyncedTableSnapshot, + from types.KeyBytes, + ts int64, + maxChunkSize int, + lru *lru, +) types.Seq { + if from == nil { + panic("BUG: makeDBIterator: nil from") + } + if maxChunkSize <= 0 { + panic("BUG: makeDBIterator: chunkSize must be > 0") + } + return func(yield func(k types.Ordered, err error) bool) { + s := &dbSeq{ + db: db, + from: from.Clone(), + sts: sts, + chunkSize: 1, + ts: ts, + maxChunkSize: maxChunkSize, + keyLen: len(from), + chunk: make([]types.KeyBytes, maxChunkSize), + singleChunk: false, + cache: lru, + } + if err := s.load(); err != nil { + yield(nil, err) + } + s.iterate(yield) + } +} + +func (s *dbSeq) loadCached(key dbIDKey) (bool, int) { + if s.cache == nil { + return false, 0 + } + chunk, ok := s.cache.Get(key) + if !ok { + // fmt.Fprintf(os.Stderr, "QQQQQ: cache miss\n") + return false, 0 + } + + // fmt.Fprintf(os.Stderr, "QQQQQ: cache hit, chunk size %d\n", len(chunk)) + for n, id := range s.chunk[:len(chunk)] { + if id == nil { + id = make([]byte, s.keyLen) + s.chunk[n] = id + } + copy(id, chunk[n]) + } + return true, len(chunk) +} + +func (s *dbSeq) load() error { + s.pos = 0 + if s.singleChunk { + // we have a single-chunk DB sequence, don't need to reload, + // just wrap around + return nil + } + + n := 0 + // if the chunk size was reduced due to a short chunk before wraparound, we need + // to extend it back + if cap(s.chunk) < s.chunkSize { + s.chunk = make([]types.KeyBytes, s.chunkSize) + } else { + s.chunk = s.chunk[:s.chunkSize] + } + 
// fmt.Fprintf(os.Stderr, "QQQQQ: from: %s chunkSize: %d\n", hex.EncodeToString(it.from), it.chunkSize) + key := dbIDKey{string(s.from), s.chunkSize} + var ierr, err error + found, n := s.loadCached(key) + if !found { + dec := func(stmt *sql.Statement) bool { + if n >= len(s.chunk) { + ierr = errors.New("too many rows") + return false + } + // we reuse existing slices when possible for retrieving new IDs + id := s.chunk[n] + if id == nil { + id = make([]byte, s.keyLen) + s.chunk[n] = id + } + stmt.ColumnBytes(0, id) + n++ + return true + } + if s.ts <= 0 { + err = s.sts.loadIDRange(s.db, s.from, s.chunkSize, dec) + } else { + err = s.sts.loadRecent(s.db, s.from, s.chunkSize, s.ts, dec) + } + if err == nil && ierr == nil && s.cache != nil { + cached := make([]types.KeyBytes, n) + for n, id := range s.chunk[:n] { + cached[n] = slices.Clone(id) + } + s.cache.Add(key, cached) + } + } + fromZero := s.from.IsZero() + s.chunkSize = min(s.chunkSize*2, s.maxChunkSize) + switch { + case err != nil || ierr != nil: + return errors.Join(ierr, err) + case n == 0: + // empty chunk + if fromZero { + // already wrapped around or started from 0, + // the set is empty + s.chunk = nil + return nil + } + // wrap around + s.from.Zero() + return s.load() + case n < len(s.chunk): + // short chunk means there are no more items after it, + // start the next chunk from 0 + s.from.Zero() + s.chunk = s.chunk[:n] + // wrapping around on an incomplete chunk that started + // from 0 means we have just a single chunk + s.singleChunk = fromZero + default: + // use last item incremented by 1 as the start of the next chunk + copy(s.from, s.chunk[n-1]) + // inc may wrap around if it's 0xffff...fff, but it's fine + if s.from.Inc() { + // if we wrapped around and the current chunk started from 0, + // we have just a single chunk + s.singleChunk = fromZero + } + } + return nil +} + +func (s *dbSeq) iterate(yield func(k types.Ordered, err error) bool) { + if len(s.chunk) == 0 { + return + } + for { + if 
s.pos >= len(s.chunk) { + panic("BUG: bad dbSeq position") + } + if !yield(slices.Clone(s.chunk[s.pos]), nil) { + break + } + s.pos++ + if s.pos >= len(s.chunk) { + if err := s.load(); err != nil { + yield(nil, err) + return + } + } + } +} diff --git a/sync2/dbsync/dbseq_test.go b/sync2/dbsync/dbseq_test.go new file mode 100644 index 0000000000..e703b52e3e --- /dev/null +++ b/sync2/dbsync/dbseq_test.go @@ -0,0 +1,289 @@ +package dbsync + +import ( + "context" + "encoding/hex" + "fmt" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +func createDB(t *testing.T, keyLen int) sql.Database { + db := sql.InMemoryTest(t) + _, err := db.Exec(fmt.Sprintf("create table foo(id char(%d) not null primary key)", keyLen), nil, nil) + require.NoError(t, err) + return db +} + +func insertDBItems(t *testing.T, db sql.Database, content []types.KeyBytes) { + err := db.WithTx(context.Background(), func(tx sql.Transaction) error { + for _, id := range content { + _, err := tx.Exec( + "insert into foo(id) values(?)", + func(stmt *sql.Statement) { + stmt.BindBytes(1, id) + }, nil) + if err != nil { + return err + } + } + return nil + }) + require.NoError(t, err) +} + +func deleteDBItems(t *testing.T, db sql.Database) { + _, err := db.Exec("delete from foo", nil, nil) + require.NoError(t, err) +} + +func populateDB(t *testing.T, keyLen int, content []types.KeyBytes) sql.Database { + db := createDB(t, keyLen) + insertDBItems(t, db, content) + return db +} + +func TestDBRangeIterator(t *testing.T) { + db := createDB(t, 4) + for _, tc := range []struct { + items []types.KeyBytes + from types.KeyBytes + fromN int + }{ + { + items: nil, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x00}, + }, + { + items: nil, + from: types.KeyBytes{0x80, 0x00, 0x00, 0x00}, + }, + { + items: nil, + from: types.KeyBytes{0xff, 0xff, 0xff, 0xff}, + }, + { + items: []types.KeyBytes{ + {0x00, 0x00, 
0x00, 0x00}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: types.KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0x00, 0x00, 0x00, 0x00}, + }, + from: types.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: types.KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0x01, 0x02, 0x03, 0x04}, + }, + from: types.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0xff, 0xff, 0xff, 0xff}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0xff, 0xff, 0xff, 0xff}, + }, + from: types.KeyBytes{0x01, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + {0xff, 0xff, 0xff, 0xff}, + }, + from: types.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x00}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x01}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x02}, + fromN: 1, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x03}, + 
fromN: 1, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x05}, + fromN: 2, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0x00, 0x00, 0x00, 0x07}, + fromN: 3, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + }, + from: types.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 0, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + 4: {0x00, 0x00, 0x01, 0x00}, + 5: {0x00, 0x00, 0x03, 0x00}, + 6: {0x00, 0x01, 0x00, 0x00}, + 7: {0x00, 0x05, 0x00, 0x00}, + 8: {0x03, 0x05, 0x00, 0x00}, + 9: {0x09, 0x05, 0x00, 0x00}, + 10: {0x0a, 0x05, 0x00, 0x00}, + 11: {0xff, 0xff, 0xff, 0xff}, + }, + from: types.KeyBytes{0x00, 0x00, 0x03, 0x01}, + fromN: 6, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + 4: {0x00, 0x00, 0x01, 0x00}, + 5: {0x00, 0x00, 0x03, 0x00}, + 6: {0x00, 0x01, 0x00, 0x00}, + 7: {0x00, 0x05, 0x00, 0x00}, + 8: {0x03, 0x05, 0x00, 0x00}, + 9: {0x09, 0x05, 0x00, 0x00}, + 10: {0x0a, 0x05, 0x00, 0x00}, + 11: {0xff, 0xff, 0xff, 0xff}, + }, + from: types.KeyBytes{0x00, 0x01, 0x00, 0x00}, + fromN: 6, + }, + { + items: []types.KeyBytes{ + 0: {0x00, 0x00, 0x00, 0x01}, + 1: {0x00, 0x00, 0x00, 0x03}, + 2: {0x00, 0x00, 0x00, 0x05}, + 3: {0x00, 0x00, 0x00, 0x07}, + 4: {0x00, 0x00, 0x01, 0x00}, + 5: {0x00, 0x00, 0x03, 0x00}, + 6: {0x00, 0x01, 0x00, 0x00}, + 7: {0x00, 0x05, 0x00, 0x00}, + 8: {0x03, 0x05, 0x00, 0x00}, + 9: {0x09, 0x05, 0x00, 0x00}, + 10: {0x0a, 0x05, 
0x00, 0x00}, + 11: {0xff, 0xff, 0xff, 0xff}, + }, + from: types.KeyBytes{0xff, 0xff, 0xff, 0xff}, + fromN: 11, + }, + } { + deleteDBItems(t, db) + insertDBItems(t, db, tc.items) + cache := newLRU() + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + sts, err := st.snapshot(db) + require.NoError(t, err) + for maxChunkSize := 1; maxChunkSize < 12; maxChunkSize++ { + seq := idsFromTable(db, sts, tc.from, -1, maxChunkSize, cache) + // when there are no items, errEmptySet is returned + for range 3 { // make sure the sequence is reusable + var collected []types.KeyBytes + var firstK types.KeyBytes + for item, err := range seq { + k := item.(types.KeyBytes) + if firstK == nil { + firstK = k + } else if k.Compare(firstK) == 0 { + break + } + collected = append(collected, k) + require.NoError(t, err) + } + expected := slices.Concat(tc.items[tc.fromN:], tc.items[:tc.fromN]) + require.Equal(t, expected, collected, "count=%d from=%s maxChunkSize=%d", + len(tc.items), hex.EncodeToString(tc.from), maxChunkSize) + } + } + } +} diff --git a/sync2/dbsync/dbset.go b/sync2/dbsync/dbset.go new file mode 100644 index 0000000000..a28f8a3c3b --- /dev/null +++ b/sync2/dbsync/dbset.go @@ -0,0 +1,248 @@ +package dbsync + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type DBSet struct { + loadMtx sync.Mutex + db sql.Database + ft *fpTree + st *SyncedTable + snapshot *SyncedTableSnapshot + dbStore *dbBackedStore + keyLen int + maxDepth int +} + +var _ rangesync.OrderedSet = &DBSet{} + +func NewDBSet( + db sql.Database, + st *SyncedTable, + keyLen, maxDepth int, +) *DBSet { + return &DBSet{ + db: db, + st: st, + keyLen: keyLen, + maxDepth: maxDepth, + } +} + +func (d *DBSet) decodeID(stmt *sql.Statement) bool { + id := make(types.KeyBytes, d.keyLen) // TODO: don't allocate new ID + stmt.ColumnBytes(0, id[:]) + 
d.ft.addStoredHash(id) + return true +} + +func (d *DBSet) EnsureLoaded(ctx context.Context) error { + d.loadMtx.Lock() + defer d.loadMtx.Unlock() + if d.ft != nil { + return nil + } + db := ContextSQLExec(ctx, d.db) + var err error + d.snapshot, err = d.st.snapshot(db) + if err != nil { + return fmt.Errorf("error taking snapshot: %w", err) + } + var np nodePool + d.dbStore = newDBBackedStore(db, d.snapshot, d.keyLen) + d.ft = newFPTree(&np, d.dbStore, d.keyLen, d.maxDepth) + return d.snapshot.loadIDs(db, d.decodeID) +} + +// Add implements hashsync.ItemStore. +func (d *DBSet) Add(ctx context.Context, k types.Ordered) error { + if err := d.EnsureLoaded(ctx); err != nil { + return err + } + has, err := d.Has(ctx, k) // TODO: this check shouldn't be needed + if has || err != nil { + return err + } + return d.ft.addHash(k.(types.KeyBytes)) +} + +// GetRangeInfo implements hashsync.ItemStore. +func (d *DBSet) GetRangeInfo( + ctx context.Context, + x, y types.Ordered, + count int, +) (rangesync.RangeInfo, error) { + if err := d.EnsureLoaded(ctx); err != nil { + return rangesync.RangeInfo{}, err + } + fpr, err := d.ft.fingerprintInterval(ctx, x.(types.KeyBytes), y.(types.KeyBytes), count) + if err != nil { + return rangesync.RangeInfo{}, err + } + return rangesync.RangeInfo{ + Fingerprint: fpr.fp, + Count: int(fpr.count), + Items: fpr.items, + }, nil +} + +func (d *DBSet) SplitRange( + ctx context.Context, + x, y types.Ordered, + count int, +) ( + rangesync.SplitInfo, + error, +) { + if count <= 0 { + panic("BUG: bad split count") + } + + if err := d.EnsureLoaded(ctx); err != nil { + return rangesync.SplitInfo{}, err + } + + sr, err := d.ft.easySplit(ctx, x.(types.KeyBytes), y.(types.KeyBytes), count) + if err == nil { + // fmt.Fprintf(os.Stderr, "QQQQQ: fast split, middle: %s\n", sr.middle.String()) + return rangesync.SplitInfo{ + Parts: [2]rangesync.RangeInfo{ + { + Fingerprint: sr.part0.fp, + Count: int(sr.part0.count), + Items: sr.part0.items, + }, + { + Fingerprint: 
sr.part1.fp, + Count: int(sr.part1.count), + Items: sr.part1.items, + }, + }, + Middle: sr.middle, + }, nil + } + + if !errors.Is(err, errEasySplitFailed) { + return rangesync.SplitInfo{}, err + } + + fpr0, err := d.ft.fingerprintInterval(ctx, x.(types.KeyBytes), y.(types.KeyBytes), count) + if err != nil { + return rangesync.SplitInfo{}, err + } + + if fpr0.count == 0 { + return rangesync.SplitInfo{}, errors.New("can't split empty range") + } + + fpr1, err := d.ft.fingerprintInterval(ctx, fpr0.next, y.(types.KeyBytes), -1) + if err != nil { + return rangesync.SplitInfo{}, err + } + + return rangesync.SplitInfo{ + Parts: [2]rangesync.RangeInfo{ + { + Fingerprint: fpr0.fp, + Count: int(fpr0.count), + Items: fpr0.items, + }, + { + Fingerprint: fpr1.fp, + Count: int(fpr1.count), + Items: fpr1.items, + }, + }, + Middle: fpr0.next, + }, nil +} + +// Min implements hashsync.ItemStore. +func (d *DBSet) Items(ctx context.Context) (types.Seq, error) { + if err := d.EnsureLoaded(ctx); err != nil { + return nil, err + } + if d.ft.count() == 0 { + return types.EmptySeq(), nil + } + return d.ft.all(ctx) +} + +// Empty implements hashsync.ItemStore. +func (d *DBSet) Empty(ctx context.Context) (bool, error) { + if err := d.EnsureLoaded(ctx); err != nil { + return false, err + } + return d.ft.count() == 0, nil +} + +func (d *DBSet) Advance(ctx context.Context) error { + d.loadMtx.Lock() + d.loadMtx.Unlock() + if d.ft == nil { + // FIXME + panic("BUG: can't advance the DBItemStore before it's loaded") + } + oldSnapshot := d.snapshot + var err error + d.snapshot, err = d.st.snapshot(ContextSQLExec(ctx, d.db)) + if err != nil { + return fmt.Errorf("error taking snapshot: %w", err) + } + d.dbStore.setSnapshot(d.snapshot) + return d.snapshot.loadIDsSince(d.db, oldSnapshot, d.decodeID) +} + +// Copy implements hashsync.ItemStore. 
+func (d *DBSet) Copy() rangesync.OrderedSet { + d.loadMtx.Lock() + d.loadMtx.Unlock() + if d.ft == nil { + // FIXME + panic("BUG: can't copy the DBItemStore before it's loaded") + } + return &DBSet{ + db: d.db, + ft: d.ft.clone(), + st: d.st, + keyLen: d.keyLen, + maxDepth: d.maxDepth, + } +} + +// Has implements hashsync.ItemStore. +func (d *DBSet) Has(ctx context.Context, k types.Ordered) (bool, error) { + if err := d.EnsureLoaded(ctx); err != nil { + return false, err + } + if d.ft.count() == 0 { + return false, nil + } + // TODO: should often be able to avoid querying the database if we check the key + // against the fptree + seq, err := d.ft.from(ctx, k.(types.KeyBytes)) + if err != nil { + return false, err + } + for curK, err := range seq { + if err != nil { + return false, err + } + return curK.Compare(k) == 0, nil + } + panic("BUG: empty sequence with a non-zero count") +} + +// Recent implements hashsync.ItemStore. +func (d *DBSet) Recent(ctx context.Context, since time.Time) (types.Seq, int, error) { + return d.dbStore.since(ctx, make(types.KeyBytes, d.keyLen), since.UnixNano()) +} diff --git a/sync2/dbsync/dbset_test.go b/sync2/dbsync/dbset_test.go new file mode 100644 index 0000000000..1b4ba8fff0 --- /dev/null +++ b/sync2/dbsync/dbset_test.go @@ -0,0 +1,287 @@ +package dbsync + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +func TestDBItemStore_Empty(t *testing.T) { + db := populateDB(t, 32, nil) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBSet(db, st, 32, 24) + ctx := context.Background() + empty, err := s.Empty(ctx) + require.NoError(t, err) + require.True(t, empty) + seq, err := s.Items(ctx) + require.NoError(t, err) + requireEmpty(t, seq) + for _, _ = range seq { + require.Fail(t, "expected an empty sequence") + } + + info, err := s.GetRangeInfo(ctx, + 
types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + -1) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) + + info, err = s.GetRangeInfo(ctx, + types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("9999000000000000000000000000000000000000000000000000000000000000"), + -1) + require.NoError(t, err) + require.Equal(t, 0, info.Count) + require.Equal(t, "000000000000000000000000", info.Fingerprint.String()) + requireEmpty(t, info.Items) +} + +func TestDBItemStore(t *testing.T) { + ids := []types.KeyBytes{ + types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + types.HexToKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + types.HexToKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + } + ctx := context.Background() + db := populateDB(t, 32, ids) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBSet(db, st, 32, 24) + seq, err := s.Items(ctx) + require.NoError(t, err) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, seq).String()) + has, err := s.Has(ctx, types.HexToKeyBytes("9876000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + require.False(t, has) + + for _, tc := range []struct { + xIdx, yIdx int + limit int + fp string + count int + startIdx, endIdx int + }{ + { + xIdx: 0, + yIdx: 0, + limit: 0, + fp: "000000000000000000000000", + count: 0, + startIdx: 0, + endIdx: 0, + }, + { + 
xIdx: 1, + yIdx: 1, + limit: -1, + fp: "642464b773377bbddddddddd", + count: 5, + startIdx: 1, + endIdx: 1, + }, + { + xIdx: 0, + yIdx: 3, + limit: -1, + fp: "4761032dcfe98ba555555555", + count: 3, + startIdx: 0, + endIdx: 3, + }, + { + xIdx: 2, + yIdx: 0, + limit: -1, + fp: "761032cfe98ba54ddddddddd", + count: 3, + startIdx: 2, + endIdx: 0, + }, + { + xIdx: 3, + yIdx: 2, + limit: 3, + fp: "2345679abcdef01888888888", + count: 3, + startIdx: 3, + endIdx: 1, + }, + } { + name := fmt.Sprintf("%d-%d_%d", tc.xIdx, tc.yIdx, tc.limit) + t.Run(name, func(t *testing.T) { + t.Logf("x %s y %s limit %d", ids[tc.xIdx], ids[tc.yIdx], tc.limit) + info, err := s.GetRangeInfo(ctx, ids[tc.xIdx], ids[tc.yIdx], tc.limit) + require.NoError(t, err) + require.Equal(t, tc.count, info.Count) + require.Equal(t, tc.fp, info.Fingerprint.String()) + require.Equal(t, ids[tc.startIdx], firstKey(t, info.Items)) + has, err := s.Has(ctx, ids[tc.startIdx]) + require.NoError(t, err) + require.True(t, has) + has, err = s.Has(ctx, ids[tc.endIdx]) + require.NoError(t, err) + require.True(t, has) + }) + } +} + +func TestDBItemStore_Add(t *testing.T) { + ids := []types.KeyBytes{ + types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + types.HexToKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := populateDB(t, 32, ids) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBSet(db, st, 32, 24) + ctx := context.Background() + seq, err := s.Items(ctx) + require.NoError(t, err) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, seq).String()) + + newID := types.HexToKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000") + require.NoError(t, s.Add(context.Background(), 
newID)) + + // // QQQQQ: rm + // s.ft.traceEnabled = true + // var sb strings.Builder + // s.ft.dump(&sb) + // t.Logf("tree:\n%s", sb.String()) + + info, err := s.GetRangeInfo(ctx, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 3, info.Count) + require.Equal(t, "761032cfe98ba54ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) +} + +func TestDBItemStore_Copy(t *testing.T) { + ids := []types.KeyBytes{ + types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + types.HexToKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := populateDB(t, 32, ids) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBSet(db, st, 32, 24) + ctx := context.Background() + seq, err := s.Items(ctx) + require.NoError(t, err) + require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", + firstKey(t, seq).String()) + + copy := s.Copy() + + info, err := copy.GetRangeInfo(ctx, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + newID := types.HexToKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000") + require.NoError(t, copy.Add(context.Background(), newID)) + + info, err = s.GetRangeInfo(ctx, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 2, info.Count) + require.Equal(t, "dddddddddddddddddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ctx, ids[2], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 3, info.Count) + require.Equal(t, "761032cfe98ba54ddddddddd", 
info.Fingerprint.String()) + require.Equal(t, ids[2], firstKey(t, info.Items)) +} + +func TestDBItemStore_Advance(t *testing.T) { + ids := []types.KeyBytes{ + types.HexToKeyBytes("0000000000000000000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("123456789abcdef0000000000000000000000000000000000000000000000000"), + types.HexToKeyBytes("5555555555555555555555555555555555555555555555555555555555555555"), + types.HexToKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), + } + db := populateDB(t, 32, ids) + st := &SyncedTable{ + TableName: "foo", + IDColumn: "id", + } + s := NewDBSet(db, st, 32, 24) + ctx := context.Background() + require.NoError(t, s.EnsureLoaded(ctx)) + + copy := s.Copy() + + info, err := s.GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + insertDBItems(t, db, []types.KeyBytes{ + types.HexToKeyBytes("abcdef1234567890000000000000000000000000000000000000000000000000"), + }) + + info, err = s.GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + require.NoError(t, s.Advance(ctx)) + + info, err = s.GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 
5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = copy.GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 4, info.Count) + require.Equal(t, "cfe98ba54761032ddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) + + info, err = s.Copy().GetRangeInfo(ctx, ids[0], ids[0], -1) + require.NoError(t, err) + require.Equal(t, 5, info.Count) + require.Equal(t, "642464b773377bbddddddddd", info.Fingerprint.String()) + require.Equal(t, ids[0], firstKey(t, info.Items)) +} diff --git a/sync2/dbsync/fptree.go b/sync2/dbsync/fptree.go index 2b98f6e8c6..8e38ee0887 100644 --- a/sync2/dbsync/fptree.go +++ b/sync2/dbsync/fptree.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/binary" - "encoding/hex" "errors" "fmt" "io" @@ -15,7 +14,7 @@ import ( "strings" "sync" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) var errEasySplitFailed = errors.New("easy split failed") @@ -34,8 +33,8 @@ func (t *trace) enter(format string, args ...any) { return } for n, arg := range args { - if it, ok := arg.(hashsync.Iterator); ok { - args[n] = formatIter(it) + if it, ok := arg.(types.Seq); ok { + args[n] = formatSeq(it) } } msg := fmt.Sprintf(format, args...) @@ -55,8 +54,8 @@ func (t *trace) leave(results ...any) { results = []any{fmt.Sprintf("", err)} break } - if it, ok := r.(hashsync.Iterator); ok { - results[n] = formatIter(it) + if it, ok := r.(types.Seq); ok { + results[n] = formatSeq(it) } } msg := t.traceStack[len(t.traceStack)-1] @@ -74,8 +73,8 @@ func (t *trace) leave(results ...any) { func (t *trace) log(format string, args ...any) { if t.traceEnabled { for n, arg := range args { - if it, ok := arg.(hashsync.Iterator); ok { - args[n] = formatIter(it) + if it, ok := arg.(types.Seq); ok { + args[n] = formatSeq(it) } } msg := fmt.Sprintf(format, args...) 
@@ -84,7 +83,7 @@ func (t *trace) log(format string, args ...any) { } const ( - fingerprintBytes = 12 + FingerprintSize = types.FingerprintSize // cachedBits = 24 // cachedSize = 1 << cachedBits // cacheMask = cachedSize - 1 @@ -92,42 +91,6 @@ const ( bit63 = 1 << 63 ) -type fingerprint [fingerprintBytes]byte - -func (fp fingerprint) Compare(other fingerprint) int { - return bytes.Compare(fp[:], other[:]) -} - -func (fp fingerprint) String() string { - return hex.EncodeToString(fp[:]) -} - -func (fp *fingerprint) update(h []byte) { - for n := range *fp { - (*fp)[n] ^= h[n] - } -} - -func (fp *fingerprint) bitFromLeft(n int) bool { - if n > fingerprintBytes*8 { - panic("BUG: bad fingerprint bit index") - } - return (fp[n>>3]>>(7-n&0x7))&1 != 0 -} - -func hexToFingerprint(s string) fingerprint { - b, err := hex.DecodeString(s) - if err != nil { - panic("bad hex fingerprint: " + err.Error()) - } - var fp fingerprint - if len(b) != len(fp) { - panic("bad hex fingerprint") - } - copy(fp[:], b) - return fp -} - type nodeIndex uint32 const noIndex = ^nodeIndex(0) @@ -136,7 +99,7 @@ type nodePool struct { rcPool[node, nodeIndex] } -func (np *nodePool) add(fp fingerprint, c uint32, left, right nodeIndex) nodeIndex { +func (np *nodePool) add(fp types.Fingerprint, c uint32, left, right nodeIndex) nodeIndex { // panic("TBD: this is invalid, adds unneeded refs") // if left != noIndex { // np.rcPool.ref(left) @@ -168,7 +131,7 @@ func (np *nodePool) node(idx nodeIndex) node { // The nodes are immutable except for refCount field, which should // only be used directly by nodePool methods type node struct { - fp fingerprint + fp types.Fingerprint c uint32 left, right nodeIndex } @@ -233,7 +196,7 @@ func (p prefix) highBit() bool { return p.bits()>>(p.len()-1) != 0 } -func (p prefix) minID(b KeyBytes) { +func (p prefix) minID(b types.KeyBytes) { if len(b) < 8 { panic("BUG: id slice too small") } @@ -244,7 +207,7 @@ func (p prefix) minID(b KeyBytes) { } } -func (p prefix) idAfter(b 
KeyBytes) { +func (p prefix) idAfter(b types.KeyBytes) { if len(b) < 8 { panic("BUG: id slice too small") } @@ -264,7 +227,7 @@ func (p prefix) idAfter(b KeyBytes) { } // QQQQQ: rm ? -// func (p prefix) maxID(b KeyBytes) { +// func (p prefix) maxID(b types.KeyBytes) { // if len(b) < 8 { // panic("BUG: id slice too small") // } @@ -289,24 +252,24 @@ func (p prefix) shift() prefix { } } -func (p prefix) match(b KeyBytes) bool { +func (p prefix) match(b types.KeyBytes) bool { return load64(b)>>(64-p.len()) == p.bits() } -func load64(h KeyBytes) uint64 { +func load64(h types.KeyBytes) uint64 { return binary.BigEndian.Uint64(h[:8]) } -func preFirst0(h KeyBytes) prefix { +func preFirst0(h types.KeyBytes) prefix { l := min(maxPrefixLen, bits.LeadingZeros64(^load64(h))) return mkprefix((1<= 0 } -func (ac *aggContext) fingreprintBelowY(fp fingerprint) bool { - k := make(KeyBytes, len(ac.x)) +func (ac *aggContext) fingreprintBelowY(fp types.Fingerprint) bool { + k := make(types.KeyBytes, len(ac.x)) copy(k, fp[:]) - k[:fingerprintBytes].inc() // 1 after max key derived from the fingerprint + k[:FingerprintSize].Inc() // 1 after max key derived from the fingerprint return bytes.Compare(k, ac.y) <= 0 } @@ -404,7 +369,7 @@ func (ac *aggContext) pruneY(node node, p prefix) bool { // to determine if it's below y return false } - k := make(KeyBytes, len(ac.y)) + k := make(types.KeyBytes, len(ac.y)) copy(k, node.fp[:]) return bytes.Compare(k, ac.y) >= 0 } @@ -434,7 +399,7 @@ func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { ac.lastPrefix = &p return true } - ac.fp.update(node.fp[:]) + ac.fp.Update(node.fp[:]) ac.count += node.c ac.lastPrefix = &p if ac.easySplit && ac.limit == 0 { @@ -455,13 +420,6 @@ func (ac *aggContext) maybeIncludeNode(node node, p prefix) bool { return true } -type idStore interface { - clone() idStore - registerHash(h KeyBytes) error - start(ctx context.Context) hashsync.Iterator - iter(ctx context.Context, from KeyBytes) hashsync.Iterator -} 
- type fpTree struct { trace // rmme idStore @@ -520,10 +478,10 @@ func (ft *fpTree) clone() *fpTree { } } -func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) nodeIndex { +func (ft *fpTree) pushDown(fpA, fpB types.Fingerprint, p prefix, curCount uint32) nodeIndex { // ft.log("QQQQQ: pushDown: fpA %s fpB %s p %s", fpA, fpB, p) fpCombined := fpA - fpCombined.update(fpB[:]) + fpCombined.Update(fpB[:]) if ft.maxDepth != 0 && p.len() == ft.maxDepth { // ft.log("QQQQQ: pushDown: add at maxDepth") return ft.np.add(fpCombined, curCount+1, noIndex, noIndex) @@ -531,8 +489,8 @@ func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) node if curCount != 1 { panic("BUG: pushDown of non-1-leaf below maxDepth") } - dirA := fpA.bitFromLeft(p.len()) - dirB := fpB.bitFromLeft(p.len()) + dirA := fpA.BitFromLeft(p.len()) + dirB := fpB.BitFromLeft(p.len()) // ft.log("QQQQQ: pushDown: bitFromLeft %d: dirA %v dirB %v", p.len(), dirA, dirB) if dirA == dirB { childIdx := ft.pushDown(fpA, fpB, p.dir(dirA), 1) @@ -560,7 +518,7 @@ func (ft *fpTree) pushDown(fpA, fpB fingerprint, p prefix, curCount uint32) node } } -func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { +func (ft *fpTree) addValue(fp types.Fingerprint, p prefix, idx nodeIndex) nodeIndex { if idx == noIndex { r := ft.np.add(fp, 1, noIndex, noIndex) // ft.log("QQQQQ: addValue: addNew fp %s p %s => %d", fp, p, r) @@ -576,8 +534,8 @@ func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { return r } fpCombined := fp - fpCombined.update(node.fp[:]) - if fp.bitFromLeft(p.len()) { + fpCombined.Update(node.fp[:]) + if fp.BitFromLeft(p.len()) { // ft.log("QQQQQ: addValue: replaceRight fp %s p %s oldIdx %d", fp, p, idx) if node.left != noIndex { ft.np.ref(node.left) @@ -600,9 +558,9 @@ func (ft *fpTree) addValue(fp fingerprint, p prefix, idx nodeIndex) nodeIndex { } } -func (ft *fpTree) addStoredHash(h KeyBytes) { - var fp fingerprint - 
fp.update(h) +func (ft *fpTree) addStoredHash(h types.KeyBytes) { + var fp types.Fingerprint + fp.Update(h) ft.rootMtx.Lock() defer ft.rootMtx.Unlock() ft.log("addStoredHash: h %s fp %s", h, fp) @@ -611,7 +569,7 @@ func (ft *fpTree) addStoredHash(h KeyBytes) { ft.releaseNode(oldRoot) } -func (ft *fpTree) addHash(h KeyBytes) error { +func (ft *fpTree) addHash(h types.KeyBytes) error { ft.log("addHash: h %s", h) if err := ft.idStore.registerHash(h); err != nil { return err @@ -649,7 +607,7 @@ func (ft *fpTree) followPrefix(from nodeIndex, p, followed prefix) (idx nodeInde // aggregated items. // It returns a boolean indicating whether the limit or the right edge (y) was reached and // an error, if any. -func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont bool, err error) { +func (ft *fpTree) aggregateEdge(x, y types.KeyBytes, p prefix, ac *aggContext) (cont bool, err error) { ft.enter("aggregateEdge: x %s y %s p %s limit %d count %d", x, y, p, ac.limit, ac.count) defer func() { ft.leave(ac.limit, ac.count, cont, err) @@ -659,63 +617,70 @@ func (ft *fpTree) aggregateEdge(x, y KeyBytes, p prefix, ac *aggContext) (cont b // so we'll have to retry using slower strategy return false, errEasySplitFailed } - if ac.limit == 0 && ac.end != nil { + if ac.limit == 0 && ac.next != nil { ft.log("aggregateEdge: limit is 0 and end already set") return false, nil } - var startFrom KeyBytes + var startFrom types.KeyBytes if x == nil { - startFrom = make(KeyBytes, ft.keyLen) + startFrom = make(types.KeyBytes, ft.keyLen) p.minID(startFrom) } else { startFrom = x } ft.log("aggregateEdge: startFrom %s", startFrom) - it := ft.iter(ac.ctx, startFrom) + seq, err := ft.from(ac.ctx, startFrom) + if err != nil { + return false, err + } if ac.limit == 0 { - ac.end = it.Clone() + next, err := seq.First() + if err != nil { + return false, err + } + ac.next = next.(types.KeyBytes).Clone() if x != nil { - ft.log("aggregateEdge: limit 0: x is not nil, setting start to %s", 
ac.start) - ac.start = ac.end + ft.log("aggregateEdge: limit 0: x is not nil, setting start to %s", ac.next.String()) + ac.items = seq } - ft.log("aggregateEdge: limit is 0 at %s", ac.end) + ft.log("aggregateEdge: limit is 0 at %s", ac.next.String()) return false, nil } if x != nil { - ac.start = it.Clone() - ft.log("aggregateEdge: x is not nil, setting start to %s", ac.start) + ac.items = seq + ft.log("aggregateEdge: x is not nil, setting start to %s", seq) } - for range ft.np.node(ft.root).c { - id, err := it.Key() + n := ft.np.node(ft.root).c + for id, err := range seq { if err != nil { return false, err } + if ac.limit == 0 { + ac.next = id.(types.KeyBytes).Clone() + ft.log("aggregateEdge: limit exhausted") + return false, nil + } + if n == 0 { + break + } ft.log("aggregateEdge: ID %s", id) if y != nil && id.Compare(y) >= 0 { - ac.end = it + ac.next = id.(types.KeyBytes).Clone() ft.log("aggregateEdge: ID is over Y: %s", id) return false, nil } - if !p.match(id.(KeyBytes)) { + if !p.match(id.(types.KeyBytes)) { ft.log("aggregateEdge: ID doesn't match the prefix: %s", id) ac.lastPrefix = &p return true, nil } - ac.fp.update(id.(KeyBytes)) + ac.fp.Update(id.(types.KeyBytes)) ac.count++ if ac.limit > 0 { ac.limit-- } - if err := it.Next(); err != nil { - ft.log("aggregateEdge: Next failed: %v", err) - return false, err - } - if ac.limit == 0 { - ac.end = it - ft.log("aggregateEdge: limit exhausted") - return false, nil - } + n-- } return true, nil @@ -1024,17 +989,32 @@ func (ft *fpTree) aggregateInterval(ac *aggContext) (err error) { } } -func (ft *fpTree) endIterFromPrefix(ac *aggContext, p prefix) hashsync.Iterator { - k := make(KeyBytes, ft.keyLen) +func (ft *fpTree) startFromPrefix(ac *aggContext, p prefix) (types.Seq, error) { + k := make(types.KeyBytes, ft.keyLen) p.idAfter(k) - ft.log("endIterFromPrefix: p: %s idAfter: %s", p, k) - return ft.iter(ac.ctx, k) + ft.log("startFromPrefix: p: %s idAfter: %s", p, k) + return ft.from(ac.ctx, k) } -func (ft *fpTree) 
fingerprintInterval(ctx context.Context, x, y KeyBytes, limit int) (fpr fpResult, err error) { +func (ft *fpTree) nextFromPrefix(ac *aggContext, p prefix) (types.KeyBytes, error) { + seq, err := ft.startFromPrefix(ac, p) + if err != nil { + return nil, err + } + id, err := seq.First() + if err != nil { + return nil, err + } + if id == nil { + return nil, nil + } + return id.(types.KeyBytes).Clone(), nil +} + +func (ft *fpTree) fingerprintInterval(ctx context.Context, x, y types.KeyBytes, limit int) (fpr fpResult, err error) { ft.enter("fingerprintInterval: x %s y %s limit %d", x, y, limit) defer func() { - ft.leave(fpr.fp, fpr.count, fpr.itype, fpr.start, fpr.end, err) + ft.leave(fpr.fp, fpr.count, fpr.itype, fpr.items, fpr.next, err) }() ac := aggContext{ctx: ctx, x: x, y: y, limit: limit} if err := ft.aggregateInterval(&ac); err != nil { @@ -1047,29 +1027,47 @@ func (ft *fpTree) fingerprintInterval(ctx context.Context, x, y KeyBytes, limit } if ac.total == 0 { + fpr.items = types.EmptySeq() return fpr, nil } - if ac.start != nil { - ft.log("fingerprintInterval: start %s", ac.start) - fpr.start = ac.start + if ac.items != nil { + ft.log("fingerprintInterval: items %s", ac.items) + fpr.items = ac.items } else { - fpr.start = ft.iter(ac.ctx, x) - ft.log("fingerprintInterval: start from x: %s", fpr.start) + fpr.items, err = ft.from(ac.ctx, x) + if err != nil { + return fpResult{}, err + } + ft.log("fingerprintInterval: start from x: %s", fpr.items) } - if ac.end != nil { - ft.log("fingerprintInterval: end %s", ac.end) - fpr.end = ac.end + if ac.next != nil { + ft.log("fingerprintInterval: next %s", ac.next) + fpr.next = ac.next } else if (fpr.itype == 0 && limit < 0) || fpr.count == 0 { - fpr.end = fpr.start - ft.log("fingerprintInterval: end at start %s", fpr.end) + next, err := fpr.items.First() + if err != nil { + return fpResult{}, err + } + if next != nil { + fpr.next = next.(types.KeyBytes).Clone() + } + ft.log("fingerprintInterval: next at start %s", fpr.next) 
} else if ac.lastPrefix != nil { - fpr.end = ft.endIterFromPrefix(&ac, *ac.lastPrefix) - ft.log("fingerprintInterval: end at lastPrefix %s -> %s", *ac.lastPrefix, fpr.end) + fpr.next, err = ft.nextFromPrefix(&ac, *ac.lastPrefix) + ft.log("fingerprintInterval: next at lastPrefix %s -> %s", *ac.lastPrefix, fpr.next) } else { - fpr.end = ft.iter(ac.ctx, y) - ft.log("fingerprintInterval: end at y: %s", fpr.end) + seq, err := ft.from(ac.ctx, y) + if err != nil { + return fpResult{}, err + } + next, err := seq.First() + if err != nil { + return fpResult{}, err + } + fpr.next = next.(types.KeyBytes).Clone() + ft.log("fingerprintInterval: next at y: %s", fpr.next) } return fpr, nil @@ -1077,18 +1075,18 @@ func (ft *fpTree) fingerprintInterval(ctx context.Context, x, y KeyBytes, limit type splitResult struct { part0, part1 fpResult - middle KeyBytes + middle types.KeyBytes } // easySplit splits an interval in two parts trying to do it in such way that the first // part has close to limit items while not making any idStore queries so that the database // is not accessed. 
If the split can't be done, which includes the situation where one of // the sides has 0 items, easySplit returns errEasySplitFailed error -func (ft *fpTree) easySplit(ctx context.Context, x, y KeyBytes, limit int) (sr splitResult, err error) { +func (ft *fpTree) easySplit(ctx context.Context, x, y types.KeyBytes, limit int) (sr splitResult, err error) { ft.enter("easySplit: x %s y %s limit %d", x, y, limit) defer func() { - ft.leave(sr.part0.fp, sr.part0.count, sr.part0.itype, sr.part0.start, sr.part0.end, - sr.part1.fp, sr.part1.count, sr.part1.itype, sr.part1.start, sr.part1.end, err) + ft.leave(sr.part0.fp, sr.part0.count, sr.part0.itype, sr.part0.items, sr.part0.next, + sr.part1.fp, sr.part1.count, sr.part1.itype, sr.part1.items, sr.part1.next, err) }() if limit < 0 { panic("BUG: easySplit with limit < 0") @@ -1117,21 +1115,32 @@ func (ft *fpTree) easySplit(ctx context.Context, x, y KeyBytes, limit int) (sr s // ac.start / ac.end are only set in aggregateEdge which fails with // errEasySplitFailed if easySplit is enabled, so we can ignore them here - middle := make(KeyBytes, ft.keyLen) + middle := make(types.KeyBytes, ft.keyLen) ac.lastPrefix0.idAfter(middle) + ft.log("easySplit: lastPrefix0 %s middle %s", ac.lastPrefix0, middle) + items, err := ft.from(ac.ctx, x) + if err != nil { + return splitResult{}, err + } part0 := fpResult{ fp: ac.fp0, count: ac.count0, itype: ac.itype, - start: ft.iter(ac.ctx, x), - end: ft.endIterFromPrefix(&ac, *ac.lastPrefix0), + items: items, + // next is only used for splitting + // next: ft.nextFromPrefix(&ac, *ac.lastPrefix0), + } + items, err = ft.startFromPrefix(&ac, *ac.lastPrefix0) + if err != nil { + return splitResult{}, err } part1 := fpResult{ fp: ac.fp, count: ac.count, itype: ac.itype, - start: part0.end.Clone(), - end: ft.endIterFromPrefix(&ac, *ac.lastPrefix), + items: items, + // next is only used for splitting + // next: ft.nextFromPrefix(&ac, *ac.lastPrefix), } return splitResult{ part0: part0, @@ -1174,22 
+1183,24 @@ func (ft *fpTree) count() int { return int(ft.np.node(ft.root).c) } -type iterFormatter struct { - it hashsync.Iterator +type seqFormatter struct { + seq types.Seq } -func (f iterFormatter) String() string { - if k, err := f.it.Key(); err != nil { - return fmt.Sprintf("", err) - } else { +func (f seqFormatter) String() string { + for k, err := range f.seq { + if err != nil { + return fmt.Sprintf("", err) + } return k.(fmt.Stringer).String() } + return "" } -func formatIter(it hashsync.Iterator) fmt.Stringer { - return iterFormatter{it: it} +func formatSeq(seq types.Seq) fmt.Stringer { + return seqFormatter{seq: seq} } // TBD: optimize, get rid of binary.BigEndian.* -// TBD: QQQQQ: detect unbalancedness when a ref gets too many items +// TBD: detect unbalancedness when a ref gets too many items // TBD: QQQQQ: ItemStore.Close(): close db conns, also free fpTree instead of using finalizer! diff --git a/sync2/dbsync/fptree_test.go b/sync2/dbsync/fptree_test.go index 3c48d085b7..768db9ed04 100644 --- a/sync2/dbsync/fptree_test.go +++ b/sync2/dbsync/fptree_test.go @@ -4,25 +4,25 @@ import ( "context" "encoding/binary" "fmt" - "math" "math/rand" "reflect" - "runtime" "slices" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/common/util" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/statesql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) +func firstKey(t *testing.T, seq types.Seq) types.KeyBytes { + k, err := seq.First() + require.NoError(t, err) + return k.(types.KeyBytes) +} + func TestPrefix(t *testing.T) { for _, tc := range []struct { p prefix @@ -144,16 +144,10 @@ func TestPrefix(t *testing.T) { require.Equal(t, tc.shift, tc.p.shift()) } - expMinID := types.HexToHash32(tc.minID) - var minID types.Hash32 - 
tc.p.minID(minID[:]) + expMinID := types.HexToKeyBytes(tc.minID) + minID := make(types.KeyBytes, 32) + tc.p.minID(minID) require.Equal(t, expMinID, minID) - - // QQQQQ: TBD: rm (probably with maxid fields?) - // expMaxID := types.HexToHash32(tc.maxID) - // var maxID types.Hash32 - // tc.p.maxID(maxID[:]) - // require.Equal(t, expMaxID, maxID) }) } } @@ -189,9 +183,9 @@ func TestCommonPrefix(t *testing.T) { p: 0xabcdef12345678ba, }, } { - a := types.HexToHash32(tc.a) - b := types.HexToHash32(tc.b) - require.Equal(t, tc.p, commonPrefix(a[:], b[:])) + a := types.HexToKeyBytes(tc.a) + b := types.HexToKeyBytes(tc.b) + require.Equal(t, tc.p, commonPrefix(a, b)) } } @@ -213,7 +207,7 @@ func newFakeATXIDStore(t *testing.T, db sql.Database, maxDepth int) *fakeIDDBSto return &fakeIDDBStore{db: db, t: t, sqlIDStore: newSQLIDStore(db, sts, 32)} } -func (s *fakeIDDBStore) registerHash(h KeyBytes) error { +func (s *fakeIDDBStore) registerHash(h types.KeyBytes) error { if err := s.sqlIDStore.registerHash(h); err != nil { return err } @@ -229,6 +223,12 @@ func (s *fakeIDDBStore) registerHash(h KeyBytes) error { type idStoreFunc func(maxDepth int) idStore +func requireEmpty(t *testing.T, seq types.Seq) { + for _, _ = range seq { + require.Fail(t, "expected an empty sequence") + } +} + func testFPTree(t *testing.T, makeIDStore idStoreFunc) { type rangeTestCase struct { xIdx, yIdx int @@ -705,17 +705,41 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { }, }, }, + { + name: "ids8", + maxDepth: 24, + ids: []string{ + "0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", + "3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", + "66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", + "90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241", + "9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", + "c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128", + }, + ranges: []rangeTestCase{ + { + x: 
"9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", + y: "0e69880000000000000000000000000000000000000000000000000000000000", + limit: -1, + fp: "5f78f3f7e073844de4501d50", + count: 2, + itype: 1, + startIdx: 4, + endIdx: 0, + }, + }, + }, } { t.Run(tc.name, func(t *testing.T) { var np nodePool idStore := makeIDStore(tc.maxDepth) ft := newFPTree(&np, idStore, 32, tc.maxDepth) // ft.traceEnabled = true - var hs []types.Hash32 + var hs []types.KeyBytes for _, hex := range tc.ids { - h := types.HexToHash32(hex) + h := types.HexToKeyBytes(hex) hs = append(hs, h) - ft.addHash(h[:]) + ft.addHash(h) } var sb strings.Builder @@ -725,11 +749,11 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { checkTree(t, ft, tc.maxDepth) for _, rtc := range tc.ranges { - var x, y types.Hash32 + var x, y types.KeyBytes var name string if rtc.x != "" { - x = types.HexToHash32(rtc.x) - y = types.HexToHash32(rtc.y) + x = types.HexToKeyBytes(rtc.x) + y = types.HexToKeyBytes(rtc.y) name = fmt.Sprintf("%s-%s_%d", rtc.x, rtc.y, rtc.limit) } else { x = hs[rtc.xIdx] @@ -739,7 +763,7 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { t.Run(name, func(t *testing.T) { fpr, err := ft.fingerprintInterval( context.Background(), - x[:], y[:], rtc.limit, + x, y, rtc.limit, ) require.NoError(t, err) assert.Equal(t, rtc.fp, fpr.fp.String(), "fp") @@ -747,19 +771,19 @@ func testFPTree(t *testing.T, makeIDStore idStoreFunc) { assert.Equal(t, rtc.itype, fpr.itype, "itype") if rtc.startIdx == -1 { - require.Nil(t, fpr.start, "start") + requireEmpty(t, fpr.items) } else { - require.NotNil(t, fpr.start, "start") - expK := KeyBytes(hs[rtc.startIdx][:]) - assert.Equal(t, expK, itKey(t, fpr.start), "start") + require.NotNil(t, fpr.items, "items") + expK := types.KeyBytes(hs[rtc.startIdx]) + assert.Equal(t, expK, firstKey(t, fpr.items), "items") } if rtc.endIdx == -1 { - require.Nil(t, fpr.end, "end") + require.Nil(t, fpr.next, "next") } else { - require.NotNil(t, fpr.end, "end") - 
expK := KeyBytes(hs[rtc.endIdx][:]) - assert.Equal(t, expK, itKey(t, fpr.end), "end") + require.NotNil(t, fpr.next, "next") + expK := types.KeyBytes(hs[rtc.endIdx]) + assert.Equal(t, expK, fpr.next, "next") } }) } @@ -794,35 +818,19 @@ func (noIDStore) clone() idStore { return &noIDStore{} } -func (noIDStore) registerHash(h KeyBytes) error { +func (noIDStore) registerHash(h types.KeyBytes) error { return nil } -func (noIDStore) start(ctx context.Context) hashsync.Iterator { +func (noIDStore) all(ctx context.Context) (types.Seq, error) { panic("no ID store") } -func (noIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { - return noIter{} -} - -type noIter struct{} - -func (noIter) Key() (hashsync.Ordered, error) { - return make(KeyBytes, 32), nil -} - -func (noIter) Next() error { - panic("no ID store") +func (noIDStore) from(ctx context.Context, from types.KeyBytes) (types.Seq, error) { + return types.EmptySeq(), nil } -func (noIter) Clone() hashsync.Iterator { - return noIter{} -} - -var _ hashsync.Iterator = &noIter{} - // TestFPTreeNoIDStore tests that an fpTree can avoid using an idStore if X has only // 0 bits below max-depth and Y has only 1 bits below max-depth. It also checks that an fpTree // can avoid using an idStore in "relaxed count" mode for splitting ranges. 
@@ -830,18 +838,18 @@ func TestFPTreeNoIDStore(t *testing.T) { var np nodePool ft := newFPTree(&np, &noIDStore{}, 32, 24) // ft.traceEnabled = true - hashes := []KeyBytes{ - util.FromHex("1111111111111111111111111111111111111111111111111111111111111111"), - util.FromHex("2222222222222222222222222222222222222222222222222222222222222222"), - util.FromHex("4444444444444444444444444444444444444444444444444444444444444444"), - util.FromHex("8888888888888888888888888888888888888888888888888888888888888888"), + hashes := []types.KeyBytes{ + types.HexToKeyBytes("1111111111111111111111111111111111111111111111111111111111111111"), + types.HexToKeyBytes("2222222222222222222222222222222222222222222222222222222222222222"), + types.HexToKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), + types.HexToKeyBytes("8888888888888888888888888888888888888888888888888888888888888888"), } for _, h := range hashes { ft.addHash(h) } for _, tc := range []struct { - x, y KeyBytes + x, y types.KeyBytes limit int fp string count uint32 @@ -854,15 +862,19 @@ func TestFPTreeNoIDStore(t *testing.T) { count: 4, }, { - x: util.FromHex("1111110000000000000000000000000000000000000000000000000000000000"), - y: util.FromHex("1111120000000000000000000000000000000000000000000000000000000000"), + x: types.HexToKeyBytes( + "1111110000000000000000000000000000000000000000000000000000000000"), + y: types.HexToKeyBytes( + "1111120000000000000000000000000000000000000000000000000000000000"), limit: -1, fp: "111111111111111111111111", count: 1, }, { - x: util.FromHex("0000000000000000000000000000000000000000000000000000000000000000"), - y: util.FromHex("9000000000000000000000000000000000000000000000000000000000000000"), + x: types.HexToKeyBytes( + "0000000000000000000000000000000000000000000000000000000000000000"), + y: types.HexToKeyBytes( + "9000000000000000000000000000000000000000000000000000000000000000"), limit: -1, fp: "ffffffffffffffffffffffff", count: 4, @@ -878,16 +890,16 @@ func 
TestFPTreeNoIDStore(t *testing.T) { func TestFPTreeClone(t *testing.T) { var np nodePool ft1 := newFPTree(&np, newInMemIDStore(32), 32, 24) - hashes := []types.Hash32{ - types.HexToHash32("1111111111111111111111111111111111111111111111111111111111111111"), - types.HexToHash32("3333333333333333333333333333333333333333333333333333333333333333"), - types.HexToHash32("4444444444444444444444444444444444444444444444444444444444444444"), + hashes := []types.KeyBytes{ + types.HexToKeyBytes("1111111111111111111111111111111111111111111111111111111111111111"), + types.HexToKeyBytes("3333333333333333333333333333333333333333333333333333333333333333"), + types.HexToKeyBytes("4444444444444444444444444444444444444444444444444444444444444444"), } - ft1.addHash(hashes[0][:]) - ft1.addHash(hashes[1][:]) + ft1.addHash(hashes[0]) + ft1.addHash(hashes[1]) ctx := context.Background() - fpr, err := ft1.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) + fpr, err := ft1.fingerprintInterval(ctx, hashes[0], hashes[0], -1) require.NoError(t, err) require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") @@ -908,22 +920,22 @@ func TestFPTreeClone(t *testing.T) { t.Logf("ft2 after-clone:\n%s", sb.String()) // original tree unchanged --- rmme!!!! 
- fpr, err = ft1.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) + fpr, err = ft1.fingerprintInterval(ctx, hashes[0], hashes[0], -1) require.NoError(t, err) require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") - ft2.addHash(hashes[2][:]) + ft2.addHash(hashes[2]) - fpr, err = ft2.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) + fpr, err = ft2.fingerprintInterval(ctx, hashes[0], hashes[0], -1) require.NoError(t, err) require.Equal(t, "666666666666666666666666", fpr.fp.String(), "fp") require.Equal(t, uint32(3), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") // original tree unchanged - fpr, err = ft1.fingerprintInterval(ctx, hashes[0][:], hashes[0][:], -1) + fpr, err = ft1.fingerprintInterval(ctx, hashes[0], hashes[0], -1) require.NoError(t, err) require.Equal(t, "222222222222222222222222", fpr.fp.String(), "fp") require.Equal(t, uint32(2), fpr.count, "count") @@ -943,20 +955,20 @@ func TestFPTreeClone(t *testing.T) { require.Zero(t, np.count()) } -type hashList []types.Hash32 +type hashList []types.KeyBytes -func (l hashList) findGTE(h types.Hash32) int { - p, _ := slices.BinarySearchFunc(l, h, func(a, b types.Hash32) int { +func (l hashList) findGTE(h types.KeyBytes) int { + p, _ := slices.BinarySearchFunc(l, h, func(a, b types.KeyBytes) int { return a.Compare(b) }) return p } -func (l hashList) keyAt(p int) KeyBytes { +func (l hashList) keyAt(p int) types.KeyBytes { if p == len(l) { p = 0 } - return KeyBytes(l[p][:]) + return types.KeyBytes(l[p]) } func checkNode(t *testing.T, ft *fpTree, idx nodeIndex, depth int) { @@ -967,18 +979,18 @@ func checkNode(t *testing.T, ft *fpTree, idx nodeIndex, depth int) { } } else { require.Less(t, depth, ft.maxDepth) - var expFP fingerprint + var expFP types.Fingerprint var expCount uint32 if node.left != noIndex { checkNode(t, ft, node.left, depth+1) left := ft.np.node(node.left) - 
expFP.update(left.fp[:]) + expFP.Update(left.fp[:]) expCount += left.c } if node.right != noIndex { checkNode(t, ft, node.right, depth+1) right := ft.np.node(node.right) - expFP.update(right.fp[:]) + expFP.Update(right.fp[:]) expCount += right.c } require.Equal(t, expFP, node.fp, "node fp at depth %d", depth) @@ -1006,11 +1018,11 @@ func repeatTestFPTreeManyItems( } type fpResultWithBounds struct { - fp fingerprint + fp types.Fingerprint count uint32 itype int - start KeyBytes - end KeyBytes + start types.KeyBytes + next types.KeyBytes } func toFPResultWithBounds(t *testing.T, fpr fpResult) fpResultWithBounds { @@ -1018,17 +1030,15 @@ func toFPResultWithBounds(t *testing.T, fpr fpResult) fpResultWithBounds { fp: fpr.fp, count: fpr.count, itype: fpr.itype, + next: fpr.next, } - if fpr.start != nil { - r.start = itKey(t, fpr.start) - } - if fpr.end != nil { - r.end = itKey(t, fpr.end) + if fpr.items != nil { + r.start = firstKey(t, fpr.items) } return r } -func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { +func dumbFP(hs hashList, x, y types.KeyBytes, limit int) fpResultWithBounds { var fpr fpResultWithBounds l := len(hs) if l == 0 { @@ -1043,11 +1053,11 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { fpr.start = hs.keyAt(p) for { if p >= pY || limit == 0 { - fpr.end = hs.keyAt(p) + fpr.next = hs.keyAt(p) break } // t.Logf("XOR %s", hs[p].String()) - fpr.fp.update(hs.keyAt(p)) + fpr.fp.Update(hs.keyAt(p)) limit-- fpr.count++ p++ @@ -1057,10 +1067,10 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { fpr.start = hs.keyAt(p) for { if p >= len(hs) || limit == 0 { - fpr.end = hs.keyAt(p) + fpr.next = hs.keyAt(p) break } - fpr.fp.update(hs.keyAt(p)) + fpr.fp.Update(hs.keyAt(p)) limit-- fpr.count++ p++ @@ -1072,10 +1082,10 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { p = 0 for { if p == pY || limit == 0 { - fpr.end = hs.keyAt(p) + fpr.next = hs.keyAt(p) 
break } - fpr.fp.update(hs.keyAt(p)) + fpr.fp.Update(hs.keyAt(p)) limit-- fpr.count++ p++ @@ -1084,13 +1094,13 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { pX := hs.findGTE(x) p := pX fpr.start = hs.keyAt(p) - fpr.end = fpr.start + fpr.next = fpr.start for { if limit == 0 { - fpr.end = hs.keyAt(p) + fpr.next = hs.keyAt(p) break } - fpr.fp.update(hs.keyAt(p)) + fpr.fp.Update(hs.keyAt(p)) limit-- fpr.count++ p = (p + 1) % l @@ -1102,9 +1112,9 @@ func dumbFP(hs hashList, x, y types.Hash32, limit int) fpResultWithBounds { return fpr } -func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, limit int) fpResult { +func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.KeyBytes, limit int) fpResult { expFPR := dumbFP(hs, x, y, limit) - fpr, err := ft.fingerprintInterval(context.Background(), x[:], y[:], limit) + fpr, err := ft.fingerprintInterval(context.Background(), x, y, limit) require.NoError(t, err) require.Equal(t, expFPR, toFPResultWithBounds(t, fpr), "x=%s y=%s limit=%d", x.String(), y.String(), limit) @@ -1127,7 +1137,7 @@ func verifyInterval(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, li return fpr } -func verifySubIntervals(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32, limit, d int) fpResult { +func verifySubIntervals(t *testing.T, hs hashList, ft *fpTree, x, y types.KeyBytes, limit, d int) fpResult { fpr := verifyInterval(t, hs, ft, x, y, limit) // t.Logf("verifySubIntervals: x=%s y=%s limit=%d => count %d", x.String(), y.String(), limit, fpr.count) if fpr.count > 1 { @@ -1136,8 +1146,8 @@ func verifySubIntervals(t *testing.T, hs hashList, ft *fpTree, x, y types.Hash32 require.Less(t, c, limit) } part := verifyInterval(t, hs, ft, x, y, c) - var m types.Hash32 - copy(m[:], itKey(t, part.end)) + m := make(types.KeyBytes, len(x)) + copy(m, part.next) verifySubIntervals(t, hs, ft, x, m, -1, d+1) verifySubIntervals(t, hs, ft, m, y, -1, d+1) } @@ -1149,30 +1159,30 @@ 
func testFPTreeManyItems(t *testing.T, idStore idStore, randomXY bool, numItems, ft := newFPTree(&np, idStore, 32, maxDepth) // ft.traceEnabled = true hs := make(hashList, numItems) - var fp fingerprint + var fp types.Fingerprint for i := range hs { - h := types.RandomHash() + h := types.RandomKeyBytes(32) hs[i] = h - ft.addHash(h[:]) - fp.update(h[:]) + ft.addHash(h) + fp.Update(h) } - slices.SortFunc(hs, func(a, b types.Hash32) int { + slices.SortFunc(hs, func(a, b types.KeyBytes) int { return a.Compare(b) }) checkTree(t, ft, maxDepth) - fpr, err := ft.fingerprintInterval(context.Background(), hs[0][:], hs[0][:], -1) + fpr, err := ft.fingerprintInterval(context.Background(), hs[0], hs[0], -1) require.NoError(t, err) require.Equal(t, fp, fpr.fp, "fp") require.Equal(t, uint32(numItems), fpr.count, "count") require.Equal(t, 0, fpr.itype, "itype") for i := 0; i < repeat; i++ { // TBD: allow reverse order - var x, y types.Hash32 + var x, y types.KeyBytes if randomXY { - x = types.RandomHash() - y = types.RandomHash() + x = types.RandomKeyBytes(32) + y = types.RandomKeyBytes(32) } else { x = hs[rand.Intn(numItems)] y = hs[rand.Intn(numItems)] @@ -1228,7 +1238,7 @@ func TestFPTreeManyItems(t *testing.T) { func verifyEasySplit( t *testing.T, ft *fpTree, - x, y KeyBytes, + x, y types.KeyBytes, depth, maxDepth int, ) ( @@ -1241,14 +1251,14 @@ func verifyEasySplit( if fpr.count <= 1 { return } - a, err := fpr.start.Key() - require.NoError(t, err) - b, err := fpr.end.Key() + a := firstKey(t, fpr.items) require.NoError(t, err) + b := fpr.next + require.NotNil(t, b) m := fpr.count / 2 // t.Logf("--- easy split %s %s %d ---", x.String(), y.String(), m) - sr, err := ft.easySplit(context.Background(), x[:], y[:], int(m)) + sr, err := ft.easySplit(context.Background(), x, y, int(m)) if err != nil { require.ErrorIs(t, err, errEasySplitFailed) return 0, 1 @@ -1260,40 +1270,34 @@ func verifyEasySplit( require.Equal(t, fpr.itype, sr.part0.itype) require.Equal(t, fpr.itype, 
sr.part1.itype) fp := sr.part0.fp - fp.update(sr.part1.fp[:]) + fp.Update(sr.part1.fp[:]) require.Equal(t, fpr.fp, fp) - require.Equal(t, a, itKey(t, sr.part0.start)) - require.Equal(t, b, itKey(t, sr.part1.end)) - precMiddle := itKey(t, sr.part0.end) - require.Equal(t, precMiddle, itKey(t, sr.part1.start)) + require.Equal(t, a, firstKey(t, sr.part0.items)) + precMiddle := firstKey(t, sr.part1.items) fpr11, err := ft.fingerprintInterval(context.Background(), x, precMiddle, -1) require.NoError(t, err) - require.Equal(t, sr.part0.fp, fpr11.fp) require.Equal(t, sr.part0.count, fpr11.count) - require.Equal(t, a, itKey(t, fpr11.start)) - require.Equal(t, precMiddle, itKey(t, fpr11.end)) + require.Equal(t, sr.part0.fp, fpr11.fp) + require.Equal(t, a, firstKey(t, fpr11.items)) fpr12, err := ft.fingerprintInterval(context.Background(), precMiddle, y, -1) require.NoError(t, err) - require.Equal(t, sr.part1.fp, fpr12.fp) require.Equal(t, sr.part1.count, fpr12.count) - require.Equal(t, precMiddle, itKey(t, fpr12.start)) - require.Equal(t, b, itKey(t, fpr12.end)) + require.Equal(t, sr.part1.fp, fpr12.fp) + require.Equal(t, precMiddle, firstKey(t, fpr12.items)) fpr11, err = ft.fingerprintInterval(context.Background(), x, sr.middle, -1) require.NoError(t, err) - require.Equal(t, sr.part0.fp, fpr11.fp) require.Equal(t, sr.part0.count, fpr11.count) - require.Equal(t, a, itKey(t, fpr11.start)) - require.Equal(t, precMiddle, itKey(t, fpr11.end)) + require.Equal(t, sr.part0.fp, fpr11.fp) + require.Equal(t, a, firstKey(t, fpr11.items)) fpr12, err = ft.fingerprintInterval(context.Background(), sr.middle, y, -1) require.NoError(t, err) - require.Equal(t, sr.part1.fp, fpr12.fp) require.Equal(t, sr.part1.count, fpr12.count) - require.Equal(t, precMiddle, itKey(t, fpr12.start)) - require.Equal(t, b, itKey(t, fpr12.end)) + require.Equal(t, sr.part1.fp, fpr12.fp) + require.Equal(t, precMiddle, firstKey(t, fpr12.items)) if depth >= maxDepth { return 1, 0 @@ -1310,13 +1314,13 @@ func 
TestEasySplit(t *testing.T) { var np nodePool ft := newFPTree(&np, newInMemIDStore(32), 32, maxDepth) for range count { - h := types.RandomHash() + h := types.RandomKeyBytes(32) // t.Logf("adding hash %s", h.String()) - ft.addHash(h[:]) + ft.addHash(h) } - k, err := ft.start(context.Background()).Key() + seq, err := ft.all(context.Background()) require.NoError(t, err) - x := k.(KeyBytes) + x := firstKey(t, seq) v := load64(x) & ^(1<<(64-maxDepth) - 1) binary.BigEndian.PutUint64(x, v) for i := 8; i < len(x); i++ { @@ -1335,300 +1339,3 @@ func TestEasySplit(t *testing.T) { require.GreaterOrEqual(t, successRate, 95.0) } } - -const dbFile = "/Users/ivan4th/Library/Application Support/Spacemesh/node-data/7c8cef2b/state.sql" - -// func dumbAggATXs(t *testing.T, db sql.StateDatabase, x, y types.Hash32) fpResult { -// var fp fingerprint -// ts := time.Now() -// nRows, err := db.Exec( -// // BETWEEN is faster than >= and < -// "select id from atxs where id between ? and ? order by id", -// func(stmt *sql.Statement) { -// stmt.BindBytes(1, x[:]) -// stmt.BindBytes(2, y[:]) -// }, -// func(stmt *sql.Statement) bool { -// var id types.Hash32 -// stmt.ColumnBytes(0, id[:]) -// if id != y { -// fp.update(id[:]) -// } -// return true -// }, -// ) -// require.NoError(t, err) -// t.Logf("QQQQQ: %v: dumb fp between %s and %s", time.Now().Sub(ts), x.String(), y.String()) -// return fpResult{ -// fp: fp, -// count: uint32(nRows), -// itype: x.Compare(y), -// } -// } - -func treeStats(t *testing.T, ft *fpTree) { - numNodes := 0 - numCompactable := 0 - numLeafs := 0 - numEarlyLeafs := 0 - minLeafSize := uint32(math.MaxUint32) - maxLeafSize := uint32(0) - totalLeafSize := uint32(0) - var scanNode func(nodeIndex, int) bool - scanNode = func(idx nodeIndex, depth int) bool { - if idx == noIndex { - return false - } - numNodes++ - node := ft.np.node(idx) - if node.leaf() { - minLeafSize = min(minLeafSize, node.c) - maxLeafSize = max(maxLeafSize, node.c) - totalLeafSize += node.c - numLeafs++ 
- if depth < ft.maxDepth { - numEarlyLeafs++ - } - } else { - haveLeft := scanNode(node.left, depth+1) - if !scanNode(node.right, depth+1) || !haveLeft { - numCompactable++ - } - } - return true - } - scanNode(ft.root, 0) - avgLeafSize := float64(totalLeafSize) / float64(numLeafs) - t.Logf("tree stats: numNodes=%d numLeafs=%d numEarlyLeafs=%d numCompactable=%d minLeafSize=%d maxLeafSize=%d avgLeafSize=%f", - numNodes, numLeafs, numEarlyLeafs, numCompactable, minLeafSize, maxLeafSize, avgLeafSize) -} - -func testATXFP(t *testing.T, maxDepth int, hs *[]types.Hash32) { - // t.Skip("slow tmp test") - // counts := make(map[uint64]uint64) - // prefLens := make(map[int]int) - // QQQQQ: TBD: reenable schema drift check - db, err := statesql.Open("file:"+dbFile, sql.WithNoCheckSchemaDrift()) - require.NoError(t, err) - defer db.Close() - // _, err = db.Exec("PRAGMA cache_size = -2000000", nil, nil) - // require.NoError(t, err) - // var prev uint64 - // first := true - // where epoch=23 - var np nodePool - if *hs == nil { - t.Logf("loading IDs") - _, err = db.Exec("select id from atxs where epoch = 26 order by id", - nil, func(stmt *sql.Statement) bool { - var id types.Hash32 - stmt.ColumnBytes(0, id[:]) - *hs = append(*hs, id) - // v := load64(id[:]) - // counts[v>>40]++ - // if first { - // first = false - // } else { - // prefLens[bits.LeadingZeros64(prev^v)]++ - // } - // prev = v - return true - }) - require.NoError(t, err) - } - - // TODO: use testing.B and b.ReportAllocs() - for i := 0; i < 3; i++ { - runtime.GC() - time.Sleep(100 * time.Millisecond) - } - var stats1 runtime.MemStats - runtime.ReadMemStats(&stats1) - // TODO: pass extra bind params to the SQL query - st := &SyncedTable{ - TableName: "atxs", - IDColumn: "id", - Filter: parseSQLExpr(t, "epoch = ?"), - Binder: func(stmt *sql.Statement) { - stmt.BindInt64(1, 26) - }, - } - sts, err := st.snapshot(db) - require.NoError(t, err) - store := newSQLIDStore(db, sts, 32) - ft := newFPTree(&np, store, 32, 
maxDepth) - for _, id := range *hs { - ft.addHash(id[:]) - } - treeStats(t, ft) - - // countFreq := make(map[uint64]int) - // for _, c := range counts { - // countFreq[c]++ - // } - // ks := maps.Keys(countFreq) - // slices.Sort(ks) - // for _, c := range ks { - // t.Logf("%d: %d times", c, countFreq[c]) - // } - // pls := maps.Keys(prefLens) - // slices.Sort(pls) - // for _, pl := range pls { - // t.Logf("pl %d: %d times", pl, prefLens[pl]) - // } - - t.Logf("benchmarking ranges") - ts := time.Now() - const numIter = 20000 - for n := 0; n < numIter; n++ { - x := types.RandomHash() - y := types.RandomHash() - ft.fingerprintInterval(context.Background(), x[:], y[:], -1) - } - elapsed := time.Now().Sub(ts) - - for i := 0; i < 3; i++ { - runtime.GC() - time.Sleep(100 * time.Millisecond) - } - var stats2 runtime.MemStats - runtime.ReadMemStats(&stats2) - t.Logf("range benchmark for maxDepth %d: %v per range, %f ranges/s, heap diff %d", - // it's important to use ft pointer here so it doesn't get freed - // before we read the mem stats - ft.maxDepth, - elapsed/numIter, - float64(numIter)/elapsed.Seconds(), - stats2.HeapInuse-stats1.HeapInuse) - - // TODO: test incomplete ranges (with limit) - t.Logf("testing ranges") - for n := 0; n < 50; n++ { - x := types.RandomHash() - y := types.RandomHash() - // t.Logf("QQQQQ: x=%s y=%s", x.String(), y.String()) - expFPResult := dumbFP(*hs, x, y, -1) - //expFPResult := dumbAggATXs(t, db, x, y) - fpr, err := ft.fingerprintInterval(context.Background(), x[:], y[:], -1) - require.NoError(t, err) - require.Equal(t, expFPResult, toFPResultWithBounds(t, fpr), - "x=%s y=%s", x.String(), y.String()) - - limit := 0 - if fpr.count != 0 { - limit = rand.Intn(int(fpr.count)) - } - // t.Logf("QQQQQ: x=%s y=%s limit=%d", x.String(), y.String(), limit) - expFPResult = dumbFP(*hs, x, y, limit) - fpr, err = ft.fingerprintInterval(context.Background(), x[:], y[:], limit) - require.NoError(t, err) - require.Equal(t, expFPResult, 
toFPResultWithBounds(t, fpr), - "x=%s y=%s limit=%d", x.String(), y.String(), limit) - } - - // x := types.HexToHash32("930a069661bf21b52aa79a4b5149ecc1190282f1386b6b8ae6b738153a7a802d") - // y := types.HexToHash32("6c966fc65c07c92e869b7796b2346a33e01c4fe38c25094a480cdcd2e7df1f56") - // t.Logf("QQQQQ: maxDepth=%d x=%s y=%s", maxDepth, x.String(), y.String()) - // expFPResult := dumbFP(*hs, x, y, -1) - // //expFPResult := dumbAggATXs(t, db, x, y) - // ft.traceEnabled = true - // fpr, err := ft.fingerprintInterval(x[:], y[:], -1) - // require.NoError(t, err) - // require.Equal(t, expFPResult, fpr, "x=%s y=%s", x.String(), y.String()) -} - -func TestATXFP(t *testing.T) { - t.Skip("slow test") - var hs []types.Hash32 - for maxDepth := 15; maxDepth <= 23; maxDepth++ { - for i := 0; i < 3; i++ { - testATXFP(t, maxDepth, &hs) - } - } -} - -// benchmarks - -// maxDepth 18: 94.739µs per range, 10555.290991 ranges/s, heap diff 16621568 -// maxDepth 18: 95.837µs per range, 10434.316922 ranges/s, heap diff 16564224 -// maxDepth 18: 95.312µs per range, 10491.834238 ranges/s, heap diff 16588800 -// maxDepth 19: 60.822µs per range, 16441.200726 ranges/s, heap diff 32317440 -// maxDepth 19: 57.86µs per range, 17283.084675 ranges/s, heap diff 32333824 -// maxDepth 19: 58.183µs per range, 17187.139809 ranges/s, heap diff 32342016 -// maxDepth 20: 41.582µs per range, 24048.516680 ranges/s, heap diff 63094784 -// maxDepth 20: 41.384µs per range, 24163.830753 ranges/s, heap diff 63102976 -// maxDepth 20: 42.003µs per range, 23807.631953 ranges/s, heap diff 63053824 -// maxDepth 21: 31.996µs per range, 31253.349138 ranges/s, heap diff 123289600 -// maxDepth 21: 31.926µs per range, 31321.766830 ranges/s, heap diff 123256832 -// maxDepth 21: 31.839µs per range, 31407.657854 ranges/s, heap diff 123256832 -// maxDepth 22: 27.829µs per range, 35933.122150 ranges/s, heap diff 240689152 -// maxDepth 22: 27.524µs per range, 36330.976995 ranges/s, heap diff 240689152 -// maxDepth 22: 27.386µs 
per range, 36514.410406 ranges/s, heap diff 240689152 -// maxDepth 23: 24.378µs per range, 41020.262869 ranges/s, heap diff 470024192 -// maxDepth 23: 24.605µs per range, 40641.096389 ranges/s, heap diff 470056960 -// maxDepth 23: 24.51µs per range, 40799.444720 ranges/s, heap diff 470040576 - -// maxDepth 18: 94.518µs per range, 10579.885738 ranges/s, heap diff 16621568 -// maxDepth 18: 95.144µs per range, 10510.332936 ranges/s, heap diff 16572416 -// maxDepth 18: 94.55µs per range, 10576.359829 ranges/s, heap diff 16588800 -// maxDepth 19: 60.463µs per range, 16538.974879 ranges/s, heap diff 32325632 -// maxDepth 19: 60.47µs per range, 16537.108181 ranges/s, heap diff 32358400 -// maxDepth 19: 60.441µs per range, 16544.939001 ranges/s, heap diff 32333824 -// maxDepth 20: 41.131µs per range, 24311.982297 ranges/s, heap diff 63078400 -// maxDepth 20: 41.621µs per range, 24026.119996 ranges/s, heap diff 63086592 -// maxDepth 20: 41.568µs per range, 24056.912641 ranges/s, heap diff 63094784 -// maxDepth 21: 32.234µs per range, 31022.459566 ranges/s, heap diff 123256832 -// maxDepth 21: 30.856µs per range, 32408.240119 ranges/s, heap diff 123248640 -// maxDepth 21: 30.774µs per range, 32494.318758 ranges/s, heap diff 123224064 -// maxDepth 22: 27.476µs per range, 36394.375781 ranges/s, heap diff 240689152 -// maxDepth 22: 27.707µs per range, 36091.188900 ranges/s, heap diff 240705536 -// maxDepth 22: 27.281µs per range, 36654.794863 ranges/s, heap diff 240705536 -// maxDepth 23: 24.394µs per range, 40992.220132 ranges/s, heap diff 470048768 -// maxDepth 23: 24.697µs per range, 40489.695824 ranges/s, heap diff 470040576 -// maxDepth 23: 24.436µs per range, 40923.081488 ranges/s, heap diff 470032384 - -// maxDepth 15: 529.513µs per range, 1888.524885 ranges/s, heap diff 2293760 -// maxDepth 15: 528.783µs per range, 1891.132520 ranges/s, heap diff 2244608 -// maxDepth 15: 529.458µs per range, 1888.723450 ranges/s, heap diff 2252800 -// maxDepth 16: 281.809µs per range, 
3548.498801 ranges/s, heap diff 4390912 -// maxDepth 16: 280.159µs per range, 3569.389929 ranges/s, heap diff 4382720 -// maxDepth 16: 280.449µs per range, 3565.709031 ranges/s, heap diff 4390912 -// maxDepth 17: 157.429µs per range, 6352.037713 ranges/s, heap diff 8527872 -// maxDepth 17: 156.569µs per range, 6386.942961 ranges/s, heap diff 8527872 -// maxDepth 17: 157.158µs per range, 6362.998907 ranges/s, heap diff 8527872 -// maxDepth 18: 94.689µs per range, 10560.886016 ranges/s, heap diff 16547840 -// maxDepth 18: 95.995µs per range, 10417.191145 ranges/s, heap diff 16564224 -// maxDepth 18: 94.469µs per range, 10585.428908 ranges/s, heap diff 16515072 -// maxDepth 19: 61.218µs per range, 16334.822475 ranges/s, heap diff 32342016 -// maxDepth 19: 61.733µs per range, 16198.549404 ranges/s, heap diff 32350208 -// maxDepth 19: 61.269µs per range, 16321.226214 ranges/s, heap diff 32309248 -// maxDepth 20: 42.336µs per range, 23620.054892 ranges/s, heap diff 63053824 -// maxDepth 20: 41.906µs per range, 23862.511368 ranges/s, heap diff 63094784 -// maxDepth 20: 41.647µs per range, 24011.273302 ranges/s, heap diff 63086592 -// maxDepth 21: 32.895µs per range, 30399.444906 ranges/s, heap diff 123256832 -// maxDepth 21: 31.798µs per range, 31447.748207 ranges/s, heap diff 123256832 -// maxDepth 21: 32.008µs per range, 31241.248008 ranges/s, heap diff 123265024 -// maxDepth 22: 27.014µs per range, 37017.223157 ranges/s, heap diff 240689152 -// maxDepth 22: 26.764µs per range, 37363.422097 ranges/s, heap diff 240664576 -// maxDepth 22: 26.938µs per range, 37121.580267 ranges/s, heap diff 240664576 -// maxDepth 23: 24.457µs per range, 40887.173321 ranges/s, heap diff 470040576 -// maxDepth 23: 24.997µs per range, 40003.930386 ranges/s, heap diff 470040576 -// maxDepth 23: 24.741µs per range, 40418.462446 ranges/s, heap diff 470040576 - -// TODO: QQQQQ: retrieve the end of the interval w/count in fpTree.fingerprintInterval() -// TODO: QQQQQ: test limits in 
TestInMemFPTreeManyItems (sep test cases SQL / non-SQL) -// TODO: the returned RangeInfo.End iterators should be cyclic - -// TBD: random off-by-1 failure? -// --- Expected -// +++ Actual -// @@ -2,5 +2,5 @@ -// fp: (dbsync.fingerprint) (len=12) { -// - 00000000 30 d4 db 9d b9 15 dd ad 75 1e 67 fd |0.......u.g.| -// + 00000000 a3 de 4b 89 7b 93 fc 76 24 88 82 b2 |..K.{..v$...| -// }, -// - count: (uint32) 41784134, -// + count: (uint32) 41784135, -// itype: (int) 1 -// Test: TestATXFP -// Messages: x=930a069661bf21b52aa79a4b5149ecc1190282f1386b6b8ae6b738153a7a802d y=6c966fc65c07c92e869b7796b2346a33e01c4fe38c25094a480cdcd2e7df1f56 diff --git a/sync2/dbsync/inmemidstore.go b/sync2/dbsync/inmemidstore.go index e6f420145a..cd85ff222a 100644 --- a/sync2/dbsync/inmemidstore.go +++ b/sync2/dbsync/inmemidstore.go @@ -3,8 +3,8 @@ package dbsync import ( "context" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" "github.com/spacemeshos/go-spacemesh/sync2/internal/skiplist" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) type inMemIDStore struct { @@ -30,49 +30,45 @@ func (s *inMemIDStore) clone() idStore { return newStore } -func (s *inMemIDStore) registerHash(h KeyBytes) error { +func (s *inMemIDStore) registerHash(h types.KeyBytes) error { s.sl.Add(h) s.len++ return nil } -func (s *inMemIDStore) start(ctx context.Context) hashsync.Iterator { - return &inMemIDStoreIterator{sl: s.sl, node: s.sl.First()} -} - -func (s *inMemIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { - node := s.sl.FindGTENode(from) - if node == nil { - node = s.sl.First() - } - return &inMemIDStoreIterator{sl: s.sl, node: node} -} - -type inMemIDStoreIterator struct { - sl *skiplist.SkipList - node *skiplist.Node -} - -var _ hashsync.Iterator = &inMemIDStoreIterator{} - -func (it *inMemIDStoreIterator) Key() (hashsync.Ordered, error) { - if it.node == nil { - return nil, errEmptySet - } - return KeyBytes(it.node.Key()), nil -} - -func (it *inMemIDStoreIterator) Next() 
error { - if it.node = it.node.Next(); it.node == nil { - it.node = it.sl.First() - if it.node == nil { - panic("BUG: iterator returned for an empty skiplist") +func (s *inMemIDStore) all(ctx context.Context) (types.Seq, error) { + return func(yield func(types.Ordered, error) bool) { + if s.sl.First() == nil { + return } - } - return nil + for node := s.sl.First(); ; node = node.Next() { + if node == nil { + node = s.sl.First() + } + if !yield(types.KeyBytes(node.Key()), nil) { + return + } + } + }, nil } -func (it *inMemIDStoreIterator) Clone() hashsync.Iterator { - cloned := *it - return &cloned +func (s *inMemIDStore) from(ctx context.Context, from types.KeyBytes) (types.Seq, error) { + return func(yield func(types.Ordered, error) bool) { + node := s.sl.FindGTENode(from) + if node == nil { + node = s.sl.First() + if node == nil { + return + } + } + for { + if !yield(types.KeyBytes(node.Key()), nil) { + return + } + node = node.Next() + if node == nil { + node = s.sl.First() + } + } + }, nil } diff --git a/sync2/dbsync/inmemidstore_test.go b/sync2/dbsync/inmemidstore_test.go index 337dccf2c7..08f2bce35f 100644 --- a/sync2/dbsync/inmemidstore_test.go +++ b/sync2/dbsync/inmemidstore_test.go @@ -2,29 +2,29 @@ package dbsync import ( "context" - "encoding/hex" "testing" "github.com/stretchr/testify/require" - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/common/util" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) func TestInMemIDStore(t *testing.T) { - var ( - it hashsync.Iterator - err error - ) s := newInMemIDStore(32) ctx := context.Background() - _, err = s.start(ctx).Key() - require.ErrorIs(t, err, errEmptySet) + seq, err := s.all(ctx) + require.NoError(t, err) + for _ = range seq { + require.Fail(t, "sequence not empty") + } - _, err = s.iter(ctx, util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")).Key() - require.ErrorIs(t, 
err, errEmptySet) + seq, err = s.from(ctx, types.HexToKeyBytes( + "0000000000000000000000000000000000000000000000000000000000000000")) + require.NoError(t, err) + for _ = range seq { + require.Fail(t, "sequence not empty") + } for _, h := range []string{ "0000000000000000000000000000000000000000000000000000000000000000", @@ -35,21 +35,21 @@ func TestInMemIDStore(t *testing.T) { "8888889999999999999999999999999999999999999999999999999999999999", "abcdef1234567890000000000000000000000000000000000000000000000000", } { - s.registerHash(util.FromHex(h)) + s.registerHash(types.HexToKeyBytes(h)) } - for i := range 6 { - if i%2 == 0 { - it = s.start(ctx) - } else { - it = s.iter( - ctx, - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) - } - var items []string - for range 7 { - items = append(items, hex.EncodeToString(itKey(t, it))) - require.NoError(t, it.Next()) + seq, err = s.all(ctx) + require.NoError(t, err) + for range 3 { // make sure seq is reusable + var r []string + n := 15 + for k, err := range seq { + require.NoError(t, err) + r = append(r, k.(types.KeyBytes).String()) + n-- + if n == 0 { + break + } } require.Equal(t, []string{ "0000000000000000000000000000000000000000000000000000000000000000", @@ -59,42 +59,47 @@ func TestInMemIDStore(t *testing.T) { "8888888888888888888888888888888888888888888888888888888888888888", "8888889999999999999999999999999999999999999999999999999999999999", "abcdef1234567890000000000000000000000000000000000000000000000000", - }, items) - require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - hex.EncodeToString(itKey(t, it))) + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", 
+ "abcdef1234567890000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + }, r) + } - s1 := s.clone() - h := types.BytesToHash( - util.FromHex("2000000000000000000000000000000000000000000000000000000000000000")) - s1.registerHash(h[:]) - items = nil - it = s1.iter( - ctx, - util.FromHex("0000000000000000000000000000000000000000000000000000000000000000")) - for range 8 { - items = append(items, hex.EncodeToString(itKey(t, it))) - require.NoError(t, it.Next()) + seq, err = s.from(ctx, types.HexToKeyBytes( + "5555555555555555555555555555555555555555555555555555555555555555")) + require.NoError(t, err) + for range 3 { // make sure seq is reusable + var r []string + n := 15 + for k, err := range seq { + require.NoError(t, err) + r = append(r, k.(types.KeyBytes).String()) + n-- + if n == 0 { + break + } } require.Equal(t, []string{ + "5555555555555555555555555555555555555555555555555555555555555555", + "8888888888888888888888888888888888888888888888888888888888888888", + "8888889999999999999999999999999999999999999999999999999999999999", + "abcdef1234567890000000000000000000000000000000000000000000000000", "0000000000000000000000000000000000000000000000000000000000000000", "1234561111111111111111111111111111111111111111111111111111111111", "123456789abcdef0000000000000000000000000000000000000000000000000", - "2000000000000000000000000000000000000000000000000000000000000000", "5555555555555555555555555555555555555555555555555555555555555555", "8888888888888888888888888888888888888888888888888888888888888888", "8888889999999999999999999999999999999999999999999999999999999999", "abcdef1234567890000000000000000000000000000000000000000000000000", - }, items) - require.Equal(t, "0000000000000000000000000000000000000000000000000000000000000000", - hex.EncodeToString(itKey(t, it))) - - it = s1.iter( - ctx, - util.FromHex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0")) - require.Equal(t, - 
"0000000000000000000000000000000000000000000000000000000000000000", - hex.EncodeToString(itKey(t, it))) + "1234561111111111111111111111111111111111111111111111111111111111", + "123456789abcdef0000000000000000000000000000000000000000000000000", + "5555555555555555555555555555555555555555555555555555555555555555", + }, r) } } diff --git a/sync2/dbsync/interface.go b/sync2/dbsync/interface.go new file mode 100644 index 0000000000..37740c2f1c --- /dev/null +++ b/sync2/dbsync/interface.go @@ -0,0 +1,14 @@ +package dbsync + +import ( + "context" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type idStore interface { + clone() idStore + registerHash(h types.KeyBytes) error + all(ctx context.Context) (types.Seq, error) + from(ctx context.Context, from types.KeyBytes) (types.Seq, error) +} diff --git a/sync2/dbsync/p2p_test.go b/sync2/dbsync/p2p_test.go index 35046e86d7..56babfc96a 100644 --- a/sync2/dbsync/p2p_test.go +++ b/sync2/dbsync/p2p_test.go @@ -3,8 +3,8 @@ package dbsync import ( "context" "encoding/binary" - "encoding/hex" "errors" + "fmt" "io" "slices" "testing" @@ -17,16 +17,16 @@ import ( "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" - "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) type fooRow struct { - id KeyBytes + id types.KeyBytes ts int64 } @@ -54,10 +54,10 @@ type syncTracer struct { sentItems int } -var _ hashsync.Tracer = &syncTracer{} +var _ rangesync.Tracer = &syncTracer{} func (tr *syncTracer) OnDumbSync() { - // QQQQQQ: use mutex and also update handler_test.go in hashsync!!!! + // QQQQQQ: use mutex and also update handler_test.go in rangesync!!!! 
tr.dumb = true } @@ -69,13 +69,13 @@ func (tr *syncTracer) OnRecent(receivedItems, sentItems int) { func verifyP2P( t *testing.T, rowsA, rowsB []fooRow, - combinedItems []KeyBytes, + combinedItems []types.KeyBytes, clockAt time.Time, receivedRecent, sentRecent bool, - opts ...hashsync.RangeSetReconcilerOption, + opts ...rangesync.RangeSetReconcilerOption, ) { - nr := hashsync.RmmeNumRead() - nw := hashsync.RmmeNumWritten() + // nr := rangesync.RmmeNumRead() + // nw := rangesync.RmmeNumWritten() const maxDepth = 24 log := zaptest.NewLogger(t) t.Logf("QQQQQ: 0") @@ -92,71 +92,72 @@ func verifyP2P( IDColumn: "id", TimestampColumn: "received", } - storeA := NewItemStoreAdapter(NewDBItemStore(dbA, st, 32, maxDepth)) + setA := NewDBSet(dbA, st, 32, maxDepth) t.Logf("QQQQQ: 2.1") require.NoError(t, dbA.WithTx(ctx, func(tx sql.Transaction) error { - return storeA.s.EnsureLoaded(WithSQLExec(ctx, tx)) + return setA.EnsureLoaded(WithSQLExec(ctx, tx)) })) t.Logf("QQQQQ: 3") - storeB := NewItemStoreAdapter(NewDBItemStore(dbB, st, 32, maxDepth)) + setB := NewDBSet(dbB, st, 32, maxDepth) t.Logf("QQQQQ: 3.1") - var x *types.Hash32 + var x types.KeyBytes require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { ctx := WithSQLExec(ctx, tx) - if err := storeB.s.EnsureLoaded(ctx); err != nil { - return err + empty, err := setB.Empty(ctx) + if err != nil { + return fmt.Errorf("check if the set is empty: %w", err) + } + if empty { + return nil } - it, err := storeB.Min(ctx) + seq, err := setB.Items(ctx) if err != nil { - return err + return fmt.Errorf("get items: %w", err) } - if it != nil { - x = &types.Hash32{} - k, err := it.Key() - if err != nil { - return err - } - h := k.(types.Hash32) - v := load64(h[:]) & ^uint64(1<<(64-maxDepth)-1) - binary.BigEndian.PutUint64(x[:], v) - for i := 8; i < len(x); i++ { - x[i] = 0 - } - t.Logf("x: %s", x.String()) + x = make(types.KeyBytes, 32) + k, err := seq.First() + if err != nil { + return fmt.Errorf("get first item: %w", err) } + v 
:= load64(k.(types.KeyBytes)) & ^uint64(1<<(64-maxDepth)-1) + binary.BigEndian.PutUint64(x[:], v) + for i := 8; i < len(x); i++ { + x[i] = 0 + } + t.Logf("x: %s", x.String()) return nil })) t.Logf("QQQQQ: 4") // QQQQQ: rmme - // storeB.s.ft.traceEnabled = true - // storeB.qqqqq = true - // require.NoError(t, storeB.s.EnsureLoaded()) + // setB.s.ft.traceEnabled = true + // setB.qqqqq = true + // require.NoError(t, setB.s.EnsureLoaded()) // var sb strings.Builder - // storeA.s.ft.dump(&sb) - // t.Logf("storeA:\n%s", sb.String()) + // setA.s.ft.dump(&sb) + // t.Logf("setA:\n%s", sb.String()) // sb = strings.Builder{} - // storeB.s.ft.dump(&sb) - // t.Logf("storeB:\n%s", sb.String()) + // setB.s.ft.dump(&sb) + // t.Logf("setB:\n%s", sb.String()) var tr syncTracer opts = append(opts, - hashsync.WithRangeReconcilerClock(clockwork.NewFakeClockAt(clockAt)), - hashsync.WithTracer(&tr), + rangesync.WithClock(clockwork.NewFakeClockAt(clockAt)), + rangesync.WithTracer(&tr), ) opts = opts[:len(opts):len(opts)] srvPeerID := mesh.Hosts()[0].ID() srv := server.New(mesh.Hosts()[0], proto, func(ctx context.Context, req []byte, stream io.ReadWriter) error { - pss := hashsync.NewPairwiseStoreSyncer(nil, append( + pss := rangesync.NewPairwiseSetSyncer(nil, append( opts, - hashsync.WithMaxSendRange(1), + rangesync.WithMaxSendRange(1), // uncomment to enable verbose logging which may slow down tests - // hashsync.WithRangeReconcilerLogger(log.Named("sideA")), + // rangesync.WithLogger(log.Named("sideA")), )) return dbA.WithTx(ctx, func(tx sql.Transaction) error { - return pss.Serve(WithSQLExec(ctx, tx), req, stream, storeA) + return pss.Serve(WithSQLExec(ctx, tx), req, stream, setA) }) }, server.WithTimeout(10*time.Second), @@ -192,19 +193,19 @@ func verifyP2P( return true }, time.Second, 10*time.Millisecond) - pss := hashsync.NewPairwiseStoreSyncer(client, append( + pss := rangesync.NewPairwiseSetSyncer(client, append( opts, - hashsync.WithMaxSendRange(1), + 
rangesync.WithMaxSendRange(1), // uncomment to enable verbose logging which may slow down tests - // hashsync.WithRangeReconcilerLogger(log.Named("sideB")), + // rangesync.WithLogger(log.Named("sideB")), )) tStart := time.Now() require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { - return pss.SyncStore(WithSQLExec(ctx, tx), srvPeerID, storeB, x, x) + return pss.Sync(WithSQLExec(ctx, tx), srvPeerID, setB, x, x) })) t.Logf("synced in %v", time.Since(tStart)) - t.Logf("bytes read: %d, bytes written: %d", hashsync.RmmeNumRead()-nr, hashsync.RmmeNumWritten()-nw) + // t.Logf("bytes read: %d, bytes written: %d", rangesync.RmmeNumRead()-nr, rangesync.RmmeNumWritten()-nw) require.Equal(t, receivedRecent, tr.receivedItems > 0) require.Equal(t, sentRecent, tr.sentItems > 0) @@ -220,48 +221,32 @@ func verifyP2P( if len(combinedItems) == 0 { return } - var actItemsA []KeyBytes + var actItemsA []types.KeyBytes require.NoError(t, dbA.WithTx(ctx, func(tx sql.Transaction) error { - it, err := storeA.Min(WithSQLExec(ctx, tx)) + seq, err := setA.Items(WithSQLExec(ctx, tx)) require.NoError(t, err) - if len(combinedItems) == 0 { - assert.Nil(t, it) + if l := len(combinedItems); l == 0 { + requireEmpty(t, seq) } else { - for range combinedItems { - k, err := it.Key() - require.NoError(t, err) - h := k.(types.Hash32) - // t.Logf("synced itemA: %s", h.String()) - actItemsA = append(actItemsA, h[:]) - require.NoError(t, it.Next()) - } - k, err := it.Key() + collected, err := types.GetN[types.KeyBytes](seq, l+1) require.NoError(t, err) - h := k.(types.Hash32) - assert.Equal(t, actItemsA[0], KeyBytes(h[:])) + actItemsA = collected[:l] + assert.Equal(t, actItemsA[0], collected[l]) // verify wraparound } return nil })) - var actItemsB []KeyBytes + var actItemsB []types.KeyBytes require.NoError(t, dbB.WithTx(ctx, func(tx sql.Transaction) error { - it, err := storeB.Min(WithSQLExec(ctx, tx)) + seq, err := setB.Items(WithSQLExec(ctx, tx)) require.NoError(t, err) - if 
len(combinedItems) == 0 { - assert.Nil(t, it) + if l := len(combinedItems); l == 0 { + requireEmpty(t, seq) } else { - for range combinedItems { - k, err := it.Key() - require.NoError(t, err) - h := k.(types.Hash32) - // t.Logf("synced itemB: %s", h.String()) - actItemsB = append(actItemsB, h[:]) - require.NoError(t, it.Next()) - } - k, err := it.Key() + collected, err := types.GetN[types.KeyBytes](seq, l+1) require.NoError(t, err) - h := k.(types.Hash32) - assert.Equal(t, actItemsB[0], KeyBytes(h[:])) + actItemsB = collected[:l] + assert.Equal(t, actItemsB[0], collected[l]) // verify wraparound } return nil })) @@ -271,20 +256,13 @@ func verifyP2P( func fooR(id string, seconds int) fooRow { return fooRow{ - hexID(id), + types.HexToKeyBytes(id), startDate.Add(time.Duration(seconds) * time.Second).UnixNano(), } } -func hexID(s string) KeyBytes { - b, err := hex.DecodeString(s) - if err != nil { - panic(err) - } - return b -} - func TestP2P(t *testing.T) { + hexID := types.HexToKeyBytes t.Run("predefined items", func(t *testing.T) { verifyP2P( t, []fooRow{ @@ -299,7 +277,7 @@ func TestP2P(t *testing.T) { fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), }, - []KeyBytes{ + []types.KeyBytes{ hexID("1111111111111111111111111111111111111111111111111111111111111111"), hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), hexID("5555555555555555555555555555555555555555555555555555555555555555"), @@ -314,24 +292,30 @@ func TestP2P(t *testing.T) { t.Run("predefined items 2", func(t *testing.T) { verifyP2P( t, []fooRow{ - fooR("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236", 10), - fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 20), - fooR("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90", 30), - fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 40), + 
fooR("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", 10), + fooR("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", 20), + fooR("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", 30), + fooR("72e1adaaf140d809a5da325a197341a453b00807ef8d8995fd3c8079b917c9d7", 40), + fooR("782c24553b0a8cf1d95f632054b7215be192facfb177cfd1312901dd4c9e0bfd", 50), + fooR("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", 60), }, []fooRow{ - fooR("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7", 11), - fooR("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701", 12), - fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 13), - fooR("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f", 14), + fooR("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0", 11), + fooR("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187", 12), + fooR("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3", 13), + fooR("90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241", 14), + fooR("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5", 15), + fooR("c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128", 16), }, - []KeyBytes{ - hexID("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), - hexID("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), - hexID("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), - hexID("d862b2413af5c252028e8f9871be8297e807661d64decd8249ac2682db168b90"), - hexID("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567"), - hexID("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f"), + []types.KeyBytes{ + hexID("0e69888877324da35693decc7ded1b2bac16d394ced869af494568d66473a6f0"), + hexID("3a78db9e386493402561d9c6f69a6b434a62388f61d06d960598ebf29a3a2187"), + 
hexID("66c9aa8f3be7da713db66e56cc165a46764f88d3113244dd5964bb0a10ccacc3"), + hexID("72e1adaaf140d809a5da325a197341a453b00807ef8d8995fd3c8079b917c9d7"), + hexID("782c24553b0a8cf1d95f632054b7215be192facfb177cfd1312901dd4c9e0bfd"), + hexID("90b25f2d1ee9c9e2d20df5f2226d14ee4223ea27ba565a49aa66a9c44a51c241"), + hexID("9e11fdb099f1118144738f9b68ca601e74b97280fd7bbc97cfc377f432e9b7b5"), + hexID("c1690e47798295cca02392cbfc0a86cb5204878c04a29b3ae7701b6b51681128"), }, startDate, false, @@ -352,7 +336,7 @@ func TestP2P(t *testing.T) { fooR("db1903851d4eba1e973fef5326cb997ea191c62a4b30d7830cc76931d28fd567", 13), fooR("fbf03324234f79a3fe0587cf5505d7e4c826cb2be38d72eafa60296ed77b3f8f", 14), }, - []KeyBytes{ + []types.KeyBytes{ hexID("80e95b39faa731eb50eae7585a8b1cae98f503481f950fdb690e60ff86c21236"), hexID("b46eb2c08f01a87aa0fd76f70dc6b1048b04a1125a44cca79c1a61932d3773d7"), hexID("bc6218a88d1648b8145fbf93ae74af8975f193af88788e7add3608e0bc50f701"), @@ -363,7 +347,7 @@ func TestP2P(t *testing.T) { startDate.Add(time.Minute), true, true, - hashsync.WithRecentTimeSpan(48*time.Second), + rangesync.WithRecentTimeSpan(48*time.Second), ) }) t.Run("empty to non-empty", func(t *testing.T) { @@ -375,7 +359,7 @@ func TestP2P(t *testing.T) { fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), }, - []KeyBytes{ + []types.KeyBytes{ hexID("1111111111111111111111111111111111111111111111111111111111111111"), hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), hexID("5555555555555555555555555555555555555555555555555555555555555555"), @@ -395,7 +379,7 @@ func TestP2P(t *testing.T) { fooR("5555555555555555555555555555555555555555555555555555555555555555", 13), fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), }, - []KeyBytes{ + []types.KeyBytes{ hexID("1111111111111111111111111111111111111111111111111111111111111111"), 
hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), hexID("5555555555555555555555555555555555555555555555555555555555555555"), @@ -404,7 +388,7 @@ func TestP2P(t *testing.T) { startDate.Add(time.Minute), true, true, - hashsync.WithRecentTimeSpan(48*time.Second), + rangesync.WithRecentTimeSpan(48*time.Second), ) }) t.Run("non-empty to empty with recent", func(t *testing.T) { @@ -417,7 +401,7 @@ func TestP2P(t *testing.T) { fooR("8888888888888888888888888888888888888888888888888888888888888888", 14), }, nil, - []KeyBytes{ + []types.KeyBytes{ hexID("1111111111111111111111111111111111111111111111111111111111111111"), hexID("123456789abcdef0000000000000000000000000000000000000000000000000"), hexID("5555555555555555555555555555555555555555555555555555555555555555"), @@ -427,7 +411,7 @@ func TestP2P(t *testing.T) { // no actual recent exchange happens due to the initial EmptySet message false, false, - hashsync.WithRecentTimeSpan(48*time.Second), + rangesync.WithRecentTimeSpan(48*time.Second), ) }) t.Run("empty to empty", func(t *testing.T) { @@ -440,23 +424,23 @@ func TestP2P(t *testing.T) { // const nShared = 8000000 // const nUniqueA = 10 // const nUniqueB = 8000 - const nShared = 80000 - const nUniqueA = 400 - const nUniqueB = 800 // const nShared = 2 // const nUniqueA = 2 // const nUniqueB = 2 - combined := make([]KeyBytes, 0, nShared+nUniqueA+nUniqueB) + const nShared = 80000 + const nUniqueA = 400 + const nUniqueB = 800 + + combined := make([]types.KeyBytes, 0, nShared+nUniqueA+nUniqueB) rowsA := make([]fooRow, nShared+nUniqueA) for i := range rowsA { - h := types.RandomHash() - k := KeyBytes(h[:]) + k := types.RandomKeyBytes(32) rowsA[i] = fooRow{ id: k, ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), } combined = append(combined, k) - // t.Logf("itemsA[%d] = %s", i, itemsA[i]) + // t.Logf("itemsA[%d] = %s", i, rowsA[i].id) } rowsB := make([]fooRow, nShared+nUniqueB) for i := range rowsB { @@ -466,17 +450,16 @@ func TestP2P(t 
*testing.T) { ts: rowsA[i].ts, } } else { - h := types.RandomHash() - k := KeyBytes(h[:]) + k := types.RandomKeyBytes(32) rowsB[i] = fooRow{ id: k, ts: startDate.Add(time.Duration(i) * time.Second).UnixNano(), } combined = append(combined, k) } - // t.Logf("itemsB[%d] = %s", i, itemsB[i]) + // t.Logf("itemsB[%d] = %s", i, rowsB[i].id) } - slices.SortFunc(combined, func(a, b KeyBytes) int { + slices.SortFunc(combined, func(a, b types.KeyBytes) int { return a.Compare(b) }) // for i, v := range combined { @@ -486,5 +469,3 @@ func TestP2P(t *testing.T) { // TODO: multiple iterations }) } - -// QQQQQ: TBD empty sets with recent diff --git a/sync2/dbsync/sqlidstore.go b/sync2/dbsync/sqlidstore.go index 177616a384..ebda8bd002 100644 --- a/sync2/dbsync/sqlidstore.go +++ b/sync2/dbsync/sqlidstore.go @@ -1,11 +1,10 @@ package dbsync import ( - "bytes" "context" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) const sqlMaxChunkSize = 1024 @@ -52,24 +51,24 @@ func (s *sqlIDStore) clone() idStore { return newSQLIDStore(s.db, s.sts, s.keyLen) } -func (s *sqlIDStore) registerHash(h KeyBytes) error { +func (s *sqlIDStore) registerHash(h types.KeyBytes) error { // should be registered by the handler code return nil } -func (s *sqlIDStore) start(ctx context.Context) hashsync.Iterator { +func (s *sqlIDStore) all(ctx context.Context) (types.Seq, error) { // TODO: should probably use a different query to get the first key - return s.iter(ctx, make(KeyBytes, s.keyLen)) + return s.from(ctx, make(types.KeyBytes, s.keyLen)) } -func (s *sqlIDStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { +func (s *sqlIDStore) from(ctx context.Context, from types.KeyBytes) (types.Seq, error) { if len(from) != s.keyLen { panic("BUG: invalid key length") } - return newDBRangeIterator(ContextSQLExec(ctx, s.db), s.sts, from, -1, sqlMaxChunkSize, s.cache) + return 
idsFromTable(ContextSQLExec(ctx, s.db), s.sts, from, -1, sqlMaxChunkSize, s.cache), nil } -func (s *sqlIDStore) iterSince(ctx context.Context, from KeyBytes, since int64) (hashsync.Iterator, int, error) { +func (s *sqlIDStore) since(ctx context.Context, from types.KeyBytes, since int64) (types.Seq, int, error) { if len(from) != s.keyLen { panic("BUG: invalid key length") } @@ -81,7 +80,7 @@ func (s *sqlIDStore) iterSince(ctx context.Context, from KeyBytes, since int64) if count == 0 { return nil, 0, nil } - return newDBRangeIterator(db, s.sts, from, since, sqlMaxChunkSize, nil), count, nil + return idsFromTable(db, s.sts, from, since, sqlMaxChunkSize, nil), count, nil } func (s *sqlIDStore) setSnapshot(sts *SyncedTableSnapshot) { @@ -109,29 +108,30 @@ func (s *dbBackedStore) clone() idStore { } } -func (s *dbBackedStore) registerHash(h KeyBytes) error { +func (s *dbBackedStore) registerHash(h types.KeyBytes) error { return s.inMemIDStore.registerHash(h) } -func (s *dbBackedStore) start(ctx context.Context) hashsync.Iterator { - dbIt := s.sqlIDStore.start(ctx) - memIt := s.inMemIDStore.start(ctx) - return combineIterators(nil, dbIt, memIt) -} - -func (s *dbBackedStore) iter(ctx context.Context, from KeyBytes) hashsync.Iterator { - dbIt := s.sqlIDStore.iter(ctx, from) - memIt := s.inMemIDStore.iter(ctx, from) - return combineIterators(from, dbIt, memIt) +func (s *dbBackedStore) all(ctx context.Context) (types.Seq, error) { + dbSeq, err := s.sqlIDStore.all(ctx) + if err != nil { + return nil, err + } + memSeq, err := s.inMemIDStore.all(ctx) + if err != nil { + return nil, err + } + return combineSeqs(nil, dbSeq, memSeq), nil } -func idWithinInterval(id, x, y KeyBytes, itype int) bool { - switch itype { - case 0: - return true - case -1: - return bytes.Compare(id, x) >= 0 && bytes.Compare(id, y) < 0 - default: - return bytes.Compare(id, y) < 0 || bytes.Compare(id, x) >= 0 +func (s *dbBackedStore) from(ctx context.Context, from types.KeyBytes) (types.Seq, error) { + 
dbSeq, err := s.sqlIDStore.from(ctx, from) + if err != nil { + return nil, err + } + memSeq, err := s.inMemIDStore.from(ctx, from) + if err != nil { + return nil, err } + return combineSeqs(from, dbSeq, memSeq), nil } diff --git a/sync2/dbsync/sqlidstore_test.go b/sync2/dbsync/sqlidstore_test.go index 03b1c1cc5b..f3cd337f5c 100644 --- a/sync2/dbsync/sqlidstore_test.go +++ b/sync2/dbsync/sqlidstore_test.go @@ -4,8 +4,10 @@ import ( "context" "testing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) func TestDBBackedStore(t *testing.T) { @@ -15,23 +17,23 @@ func TestDBBackedStore(t *testing.T) { nil, nil) require.NoError(t, err) for _, row := range []struct { - id KeyBytes + id types.KeyBytes ts int64 }{ { - id: KeyBytes{0, 0, 0, 1, 0, 0, 0, 0}, + id: types.KeyBytes{0, 0, 0, 1, 0, 0, 0, 0}, ts: 100, }, { - id: KeyBytes{0, 0, 0, 3, 0, 0, 0, 0}, + id: types.KeyBytes{0, 0, 0, 3, 0, 0, 0, 0}, ts: 200, }, { - id: KeyBytes{0, 0, 0, 5, 0, 0, 0, 0}, + id: types.KeyBytes{0, 0, 0, 5, 0, 0, 0, 0}, ts: 300, }, { - id: KeyBytes{0, 0, 0, 7, 0, 0, 0, 0}, + id: types.KeyBytes{0, 0, 0, 7, 0, 0, 0, 0}, ts: 400, }, } { @@ -52,13 +54,11 @@ func TestDBBackedStore(t *testing.T) { require.NoError(t, err) verify := func(t *testing.T, ctx context.Context) { store := newDBBackedStore(db, sts, 8) - it := store.iter(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) - var actualIDs []KeyBytes - for range 5 { - actualIDs = append(actualIDs, itKey(t, it)) - require.NoError(t, it.Next()) - } - require.Equal(t, []KeyBytes{ + seq, err := store.from(ctx, types.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) + require.NoError(t, err) + actualIDs, err := types.GetN[types.KeyBytes](seq, 5) + require.NoError(t, err) + require.Equal(t, []types.KeyBytes{ {0, 0, 0, 1, 0, 0, 0, 0}, {0, 0, 0, 3, 0, 0, 0, 0}, {0, 0, 0, 5, 0, 0, 0, 0}, @@ -66,35 +66,30 @@ func TestDBBackedStore(t *testing.T) { {0, 0, 0, 
1, 0, 0, 0, 0}, // wrapped around }, actualIDs) - it = store.start(ctx) - for n := range 5 { - require.Equal(t, actualIDs[n], itKey(t, it)) - require.NoError(t, it.Next()) - } + seq, err = store.all(ctx) + require.NoError(t, err) + actualIDs1, err := types.GetN[types.KeyBytes](seq, 5) + require.NoError(t, err) + require.Equal(t, actualIDs, actualIDs1) actualIDs = nil - it, count, err := store.iterSince(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}, 300) + seq, count, err := store.since(ctx, types.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}, 300) require.NoError(t, err) require.Equal(t, 2, count) - for range 3 { - actualIDs = append(actualIDs, itKey(t, it)) - require.NoError(t, it.Next()) - } - require.Equal(t, []KeyBytes{ + actualIDs, err = types.GetN[types.KeyBytes](seq, 3) + require.Equal(t, []types.KeyBytes{ {0, 0, 0, 5, 0, 0, 0, 0}, {0, 0, 0, 7, 0, 0, 0, 0}, {0, 0, 0, 5, 0, 0, 0, 0}, // wrapped around }, actualIDs) - require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) - require.NoError(t, store.registerHash(KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) + require.NoError(t, store.registerHash(types.KeyBytes{0, 0, 0, 2, 0, 0, 0, 0})) + require.NoError(t, store.registerHash(types.KeyBytes{0, 0, 0, 9, 0, 0, 0, 0})) actualIDs = nil - it = store.iter(ctx, KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) - for range 6 { - actualIDs = append(actualIDs, itKey(t, it)) - require.NoError(t, it.Next()) - } - require.Equal(t, []KeyBytes{ + seq, err = store.from(ctx, types.KeyBytes{0, 0, 0, 0, 0, 0, 0, 0}) + actualIDs, err = types.GetN[types.KeyBytes](seq, 6) + require.NoError(t, err) + require.Equal(t, []types.KeyBytes{ {0, 0, 0, 1, 0, 0, 0, 0}, {0, 0, 0, 2, 0, 0, 0, 0}, {0, 0, 0, 3, 0, 0, 0, 0}, diff --git a/sync2/dbsync/syncedtable.go b/sync2/dbsync/syncedtable.go index aba9ac9e73..350e15107b 100644 --- a/sync2/dbsync/syncedtable.go +++ b/sync2/dbsync/syncedtable.go @@ -4,7 +4,9 @@ import ( "fmt" rsql "github.com/rqlite/sql" + "github.com/spacemeshos/go-spacemesh/sql" + 
"github.com/spacemeshos/go-spacemesh/sync2/types" ) type Binder func(s *sql.Statement) @@ -232,7 +234,7 @@ func (sts *SyncedTableSnapshot) loadIDsSince( func (sts *SyncedTableSnapshot) loadIDRange( db sql.Executor, - fromID KeyBytes, + fromID types.KeyBytes, limit int, dec func(stmt *sql.Statement) bool, ) error { @@ -281,7 +283,7 @@ func (sts *SyncedTableSnapshot) loadRecentCount( func (sts *SyncedTableSnapshot) loadRecent( db sql.Executor, - fromID KeyBytes, + fromID types.KeyBytes, limit int, since int64, dec func(stmt *sql.Statement) bool, diff --git a/sync2/dbsync/syncedtable_test.go b/sync2/dbsync/syncedtable_test.go index 1285f497c7..916429c4e1 100644 --- a/sync2/dbsync/syncedtable_test.go +++ b/sync2/dbsync/syncedtable_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/util" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) func parseSQLExpr(t *testing.T, s string) rsql.Expr { @@ -110,7 +111,7 @@ func TestSyncedTable_LoadIDs(t *testing.T) { mkDecode := func(ids *[]string) func(stmt *sql.Statement) bool { return func(stmt *sql.Statement) bool { - id := make(KeyBytes, stmt.ColumnLen(0)) + id := make(types.KeyBytes, stmt.ColumnLen(0)) stmt.ColumnBytes(0, id) *ids = append(*ids, id.String()) return true @@ -129,7 +130,7 @@ func TestSyncedTable_LoadIDs(t *testing.T) { return ids } - loadIDRange := func(sts *SyncedTableSnapshot, from KeyBytes, limit int) []string { + loadIDRange := func(sts *SyncedTableSnapshot, from types.KeyBytes, limit int) []string { var ids []string require.NoError(t, sts.loadIDRange(db, from, limit, mkDecode(&ids))) return ids @@ -146,7 +147,7 @@ func TestSyncedTable_LoadIDs(t *testing.T) { loadRecent := func( sts *SyncedTableSnapshot, - from KeyBytes, + from types.KeyBytes, limit int, ts int64, ) []string { diff --git a/sync2/hashsync/handler.go b/sync2/hashsync/handler.go deleted file mode 100644 index 2b16c6ad8b..0000000000 --- a/sync2/hashsync/handler.go +++ 
/dev/null @@ -1,374 +0,0 @@ -package hashsync - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "sync/atomic" - "time" - - "github.com/spacemeshos/go-spacemesh/codec" - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/p2p" -) - -type sendable interface { - codec.Encodable - Type() MessageType -} - -// QQQQQ: rmme -var ( - numRead atomic.Int64 - numWritten atomic.Int64 -) - -func RmmeNumRead() int64 { - return numRead.Load() -} - -func RmmeNumWritten() int64 { - return numWritten.Load() -} - -type rmmeCountingStream struct { - io.ReadWriter -} - -// Read implements io.ReadWriter. -func (r *rmmeCountingStream) Read(p []byte) (n int, err error) { - n, err = r.ReadWriter.Read(p) - numRead.Add(int64(n)) - return n, err -} - -// Write implements io.ReadWriter. -func (r *rmmeCountingStream) Write(p []byte) (n int, err error) { - n, err = r.ReadWriter.Write(p) - numWritten.Add(int64(n)) - return n, err -} - -type conduitState int - -type wireConduit struct { - stream io.ReadWriter - initReqBuf *bytes.Buffer - // rmmePrint bool -} - -var _ Conduit = &wireConduit{} - -// NextMessage implements Conduit. 
-func (c *wireConduit) NextMessage() (SyncMessage, error) { - var b [1]byte - if _, err := io.ReadFull(c.stream, b[:]); err != nil { - if !errors.Is(err, io.EOF) { - return nil, err - } - return nil, nil - } - mtype := MessageType(b[0]) - // fmt.Fprintf(os.Stderr, "QQQQQ: wireConduit: receive message type %s\n", mtype) - switch mtype { - case MessageTypeDone: - return &DoneMessage{}, nil - case MessageTypeEndRound: - return &EndRoundMessage{}, nil - case MessageTypeItemBatch: - var m ItemBatchMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - case MessageTypeEmptySet: - return &EmptySetMessage{}, nil - case MessageTypeEmptyRange: - var m EmptyRangeMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - case MessageTypeFingerprint: - var m FingerprintMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - case MessageTypeRangeContents: - var m RangeContentsMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - case MessageTypeProbe: - var m ProbeMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - case MessageTypeSample: - var m SampleMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - case MessageTypeRecent: - var m RecentMessage - if _, err := codec.DecodeFrom(c.stream, &m); err != nil { - return nil, err - } - return &m, nil - default: - return nil, fmt.Errorf("invalid message code %02x", b[0]) - } -} - -func (c *wireConduit) send(m sendable) error { - // fmt.Fprintf(os.Stderr, "QQQQQ: send: %s: %#v\n", m.Type(), m) - var stream io.Writer - if c.initReqBuf != nil { - stream = c.initReqBuf - } else if c.stream == nil { - panic("BUG: wireConduit: no stream") - } else { - stream = c.stream - } - b := []byte{byte(m.Type())} - if _, err := 
stream.Write(b); err != nil { - return err - } - _, err := codec.EncodeTo(stream, m) - return err -} - -func (c *wireConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { - return c.send(&FingerprintMessage{ - RangeX: OrderedToCompactHash32(x), - RangeY: OrderedToCompactHash32(y), - RangeFingerprint: fingerprint.(types.Hash12), - NumItems: uint32(count), - }) -} - -func (c *wireConduit) SendEmptySet() error { - return c.send(&EmptySetMessage{}) -} - -func (c *wireConduit) SendEmptyRange(x, y Ordered) error { - return c.send(&EmptyRangeMessage{ - RangeX: OrderedToCompactHash32(x), - RangeY: OrderedToCompactHash32(y), - }) -} - -func (c *wireConduit) SendRangeContents(x, y Ordered, count int) error { - return c.send(&RangeContentsMessage{ - RangeX: OrderedToCompactHash32(x), - RangeY: OrderedToCompactHash32(y), - NumItems: uint32(count), - }) -} - -func (c *wireConduit) SendChunk(items []Ordered) error { - msg := ItemBatchMessage{ - ContentKeys: make([]types.Hash32, len(items)), - } - for n, k := range items { - msg.ContentKeys[n] = k.(types.Hash32) - } - return c.send(&msg) -} - -func (c *wireConduit) SendEndRound() error { - return c.send(&EndRoundMessage{}) -} - -func (c *wireConduit) SendDone() error { - return c.send(&DoneMessage{}) -} - -func (c *wireConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) error { - m := &ProbeMessage{ - RangeFingerprint: fingerprint.(types.Hash12), - SampleSize: uint32(sampleSize), - } - if x == nil && y == nil { - return c.send(m) - } else if x == nil || y == nil { - panic("BUG: SendProbe: bad range: just one of the bounds is nil") - } - m.RangeX = OrderedToCompactHash32(x) - m.RangeY = OrderedToCompactHash32(y) - return c.send(m) -} - -func (c *wireConduit) SendSample(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { - m := &SampleMessage{ - RangeFingerprint: fingerprint.(types.Hash12), - NumItems: uint32(count), - Sample: make([]MinhashSampleItem, sampleSize), - } - // 
fmt.Fprintf(os.Stderr, "QQQQQ: begin sending items\n") - for n := 0; n < sampleSize; n++ { - k, err := it.Key() - if err != nil { - return err - } - m.Sample[n] = MinhashSampleItemFromHash32(k.(types.Hash32)) - // fmt.Fprintf(os.Stderr, "QQQQQ: SEND: m.Sample[%d] = %s (full %s)\n", n, m.Sample[n], k.(types.Hash32).String()) - if err := it.Next(); err != nil { - return err - } - } - // fmt.Fprintf(os.Stderr, "QQQQQ: end sending items\n") - if x == nil && y == nil { - return c.send(m) - } else if x == nil || y == nil { - panic("BUG: SendProbe: bad range: just one of the bounds is nil") - } - m.RangeX = OrderedToCompactHash32(x) - m.RangeY = OrderedToCompactHash32(y) - return c.send(m) -} - -func (c *wireConduit) SendRecent(since time.Time) error { - return c.send(&RecentMessage{ - SinceTime: uint64(since.UnixNano()), - }) -} - -func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, error) { - c.initReqBuf = new(bytes.Buffer) - defer func() { c.initReqBuf = nil }() - if err := toCall(c); err != nil { - return nil, err - } - return c.initReqBuf.Bytes(), nil -} - -func (c *wireConduit) handleStream(ctx context.Context, stream io.ReadWriter, rsr *RangeSetReconciler) error { - c.stream = stream - for { - // Process() will receive all items and messages from the peer - syncDone, err := rsr.Process(ctx, c) - if err != nil { - return err - } else if syncDone { - return nil - } - } -} - -// ShortenKey implements Conduit. 
-func (c *wireConduit) ShortenKey(k Ordered) Ordered { - return MinhashSampleItemFromHash32(k.(types.Hash32)) -} - -type PairwiseStoreSyncer struct { - r Requester - opts []RangeSetReconcilerOption -} - -var _ PairwiseSyncer = &PairwiseStoreSyncer{} - -func NewPairwiseStoreSyncer(r Requester, opts []RangeSetReconcilerOption) *PairwiseStoreSyncer { - return &PairwiseStoreSyncer{r: r, opts: opts} -} - -func (pss *PairwiseStoreSyncer) Probe( - ctx context.Context, - peer p2p.Peer, - is ItemStore, - x, y *types.Hash32, -) (ProbeResult, error) { - var ( - err error - initReq []byte - info RangeInfo - pr ProbeResult - ) - var c wireConduit - rsr := NewRangeSetReconciler(is, pss.opts...) - if x == nil { - initReq, err = c.withInitialRequest(func(c Conduit) error { - info, err = rsr.InitiateProbe(ctx, c) - return err - }) - } else { - initReq, err = c.withInitialRequest(func(c Conduit) error { - info, err = rsr.InitiateBoundedProbe(ctx, c, *x, *y) - return err - }) - } - if err != nil { - return ProbeResult{}, err - } - err = pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { - c.stream = stream - var err error - pr, err = rsr.HandleProbeResponse(&c, info) - return err - }) - if err != nil { - return ProbeResult{}, err - } - return pr, nil -} - -func (pss *PairwiseStoreSyncer) SyncStore( - ctx context.Context, - peer p2p.Peer, - is ItemStore, - x, y *types.Hash32, -) error { - var c wireConduit - rsr := NewRangeSetReconciler(is, pss.opts...) 
- // c.rmmePrint = true - var ( - initReq []byte - err error - ) - if x == nil { - initReq, err = c.withInitialRequest(func(c Conduit) error { - return rsr.Initiate(ctx, c) - }) - } else { - initReq, err = c.withInitialRequest(func(c Conduit) error { - return rsr.InitiateBounded(ctx, c, *x, *y) - }) - } - if err != nil { - return err - } - return pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { - s := &rmmeCountingStream{ReadWriter: stream} - return c.handleStream(ctx, s, rsr) - }) -} - -func (pss *PairwiseStoreSyncer) Serve( - ctx context.Context, - req []byte, - stream io.ReadWriter, - is ItemStore, -) error { - var c wireConduit - rsr := NewRangeSetReconciler(is, pss.opts...) - s := struct { - io.Reader - io.Writer - }{ - // prepend the received request to data being read - Reader: io.MultiReader(bytes.NewBuffer(req), stream), - Writer: stream, - } - return c.handleStream(ctx, s, rsr) -} - -// TODO: request duration -// TODO: validate counts -// TODO: don't forget about Initiate!!! 
-// TBD: use MessageType instead of byte diff --git a/sync2/hashsync/handler_test.go b/sync2/hashsync/handler_test.go deleted file mode 100644 index 79791f71db..0000000000 --- a/sync2/hashsync/handler_test.go +++ /dev/null @@ -1,682 +0,0 @@ -package hashsync - -import ( - "bytes" - "context" - "fmt" - "io" - "slices" - "testing" - "time" - - "github.com/jonboulle/clockwork" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" - "github.com/stretchr/testify/require" - "go.uber.org/zap/zaptest" - "golang.org/x/sync/errgroup" - - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/p2p/server" -) - -type incomingRequest struct { - initialRequest []byte - stream io.ReadWriter -} - -type fakeRequester struct { - id p2p.Peer - handler server.StreamHandler - peers map[p2p.Peer]*fakeRequester - reqCh chan incomingRequest - bytesSent uint32 - bytesReceived uint32 -} - -var _ Requester = &fakeRequester{} - -func newFakeRequester(id p2p.Peer, handler server.StreamHandler, peers ...Requester) *fakeRequester { - fr := &fakeRequester{ - id: id, - handler: handler, - reqCh: make(chan incomingRequest), - peers: make(map[p2p.Peer]*fakeRequester), - } - for _, p := range peers { - pfr := p.(*fakeRequester) - fr.peers[pfr.id] = pfr - } - return fr -} - -func (fr *fakeRequester) Run(ctx context.Context) error { - if fr.handler == nil { - panic("no handler") - } - for { - var req incomingRequest - select { - case <-ctx.Done(): - return nil - case req = <-fr.reqCh: - } - if err := fr.handler(ctx, req.initialRequest, req.stream); err != nil { - panic("handler error: " + err.Error()) - } - } -} - -func (fr *fakeRequester) request( - ctx context.Context, - pid p2p.Peer, - initialRequest []byte, - callback server.StreamRequestCallback, -) error { - p, found := fr.peers[pid] - if !found { - return fmt.Errorf("bad peer %q", pid) - } - r, w := io.Pipe() - defer r.Close() - defer w.Close() - stream := struct { - 
io.Reader - io.Writer - }{ - Reader: r, - Writer: w, - } - select { - case p.reqCh <- incomingRequest{ - initialRequest: initialRequest, - stream: stream, - }: - case <-ctx.Done(): - return ctx.Err() - } - return callback(ctx, stream) -} - -func (fr *fakeRequester) StreamRequest( - ctx context.Context, - pid p2p.Peer, - initialRequest []byte, - callback server.StreamRequestCallback, - extraProtocols ...string, -) error { - return fr.request(ctx, pid, initialRequest, callback) -} - -type sliceIterator struct { - s []Ordered -} - -var _ Iterator = &sliceIterator{} - -func (it *sliceIterator) Equal(other Iterator) bool { - // not used by wireConduit - return false -} - -func (it *sliceIterator) Key() (Ordered, error) { - if len(it.s) != 0 { - return it.s[0], nil - } - return nil, nil -} - -func (it *sliceIterator) Next() error { - if len(it.s) != 0 { - it.s = it.s[1:] - } - return nil -} - -func (it *sliceIterator) Clone() Iterator { - return &sliceIterator{s: it.s} -} - -type fakeSend struct { - x, y Ordered - count int - fp any - items []Ordered - endRound bool - done bool -} - -func (fs *fakeSend) send(c Conduit) error { - switch { - case fs.endRound: - return c.SendEndRound() - case fs.done: - return c.SendDone() - case len(fs.items) != 0: - return c.SendChunk(slices.Clone(fs.items)) - case fs.x == nil || fs.y == nil: - return c.SendEmptySet() - case fs.count == 0: - return c.SendEmptyRange(fs.x, fs.y) - case fs.fp != nil: - return c.SendFingerprint(fs.x, fs.y, fs.fp, fs.count) - default: - return c.SendRangeContents(fs.x, fs.y, fs.count) - } -} - -type fakeRound struct { - name string - expectMsgs []SyncMessage - toSend []*fakeSend -} - -func (r *fakeRound) handleMessages(t *testing.T, c Conduit) error { - var msgs []SyncMessage - for { - msg, err := c.NextMessage() - if err != nil { - return fmt.Errorf("NextMessage(): %w", err) - } else if msg == nil { - break - } - msgs = append(msgs, msg) - if msg.Type() == MessageTypeDone || msg.Type() == MessageTypeEndRound 
{ - break - } - } - require.Equal(t, r.expectMsgs, msgs, "messages for round %q", r.name) - return nil -} - -func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { - if err := r.handleMessages(t, c); err != nil { - return err - } - for _, s := range r.toSend { - if err := s.send(c); err != nil { - return err - } - } - return nil -} - -func makeTestStreamHandler(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamHandler { - cbk := makeTestRequestCallback(t, c, rounds) - return func(ctx context.Context, initialRequest []byte, stream io.ReadWriter) error { - t.Logf("init request bytes: %d", len(initialRequest)) - s := struct { - io.Reader - io.Writer - }{ - // prepend the received request to data being read - Reader: io.MultiReader(bytes.NewBuffer(initialRequest), stream), - Writer: stream, - } - return cbk(ctx, s) - } -} - -func makeTestRequestCallback(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamRequestCallback { - return func(ctx context.Context, stream io.ReadWriter) error { - if c == nil { - c = &wireConduit{stream: stream} - } else { - c.stream = stream - } - for _, round := range rounds { - if err := round.handleConversation(t, c); err != nil { - return err - } - } - return nil - } -} - -func TestWireConduit(t *testing.T) { - hs := make([]types.Hash32, 16) - for n := range hs { - hs[n] = types.RandomHash() - } - fp := types.Hash12(hs[2][:12]) - srvHandler := makeTestStreamHandler(t, nil, []fakeRound{ - { - name: "server got 1st request", - expectMsgs: []SyncMessage{ - &FingerprintMessage{ - RangeX: Hash32ToCompact(hs[0]), - RangeY: Hash32ToCompact(hs[1]), - RangeFingerprint: fp, - NumItems: 4, - }, - &EndRoundMessage{}, - }, - toSend: []*fakeSend{ - { - x: hs[0], - y: hs[3], - count: 2, - }, - { - x: hs[3], - y: hs[6], - count: 2, - }, - { - items: []Ordered{hs[4], hs[5], hs[7], hs[8]}, - }, - { - endRound: true, - }, - }, - }, - { - name: "server got 2nd request", - expectMsgs: []SyncMessage{ - 
&ItemBatchMessage{ - ContentKeys: []types.Hash32{hs[9], hs[10], hs[11]}, - }, - &EndRoundMessage{}, - }, - toSend: []*fakeSend{ - { - done: true, - }, - }, - }, - }) - - srv := newFakeRequester("srv", srvHandler) - var eg errgroup.Group - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - cancel() - eg.Wait() - }() - eg.Go(func() error { - return srv.Run(ctx) - }) - - client := newFakeRequester("client", nil, srv) - var c wireConduit - initReq, err := c.withInitialRequest(func(c Conduit) error { - if err := c.SendFingerprint(hs[0], hs[1], fp, 4); err != nil { - return err - } - return c.SendEndRound() - }) - require.NoError(t, err) - clientCbk := makeTestRequestCallback(t, &c, []fakeRound{ - { - name: "client got 1st response", - expectMsgs: []SyncMessage{ - &RangeContentsMessage{ - RangeX: Hash32ToCompact(hs[0]), - RangeY: Hash32ToCompact(hs[3]), - NumItems: 2, - }, - &RangeContentsMessage{ - RangeX: Hash32ToCompact(hs[3]), - RangeY: Hash32ToCompact(hs[6]), - NumItems: 2, - }, - &ItemBatchMessage{ - ContentKeys: []types.Hash32{hs[4], hs[5], hs[7], hs[8]}, - }, - &EndRoundMessage{}, - }, - toSend: []*fakeSend{ - { - items: []Ordered{hs[9], hs[10], hs[11]}, - }, - { - endRound: true, - }, - }, - }, - { - name: "client got 2nd response", - expectMsgs: []SyncMessage{ - &DoneMessage{}, - }, - }, - }) - err = client.StreamRequest(context.Background(), "srv", initReq, clientCbk) - require.NoError(t, err) -} - -type getRequesterFunc func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) - -func withClientServer( - store ItemStore, - getRequester getRequesterFunc, - opts []RangeSetReconcilerOption, - toCall func(ctx context.Context, client Requester, srvPeerID p2p.Peer), -) { - srvHandler := func(ctx context.Context, req []byte, stream io.ReadWriter) error { - pss := NewPairwiseStoreSyncer(nil, opts) - return pss.Serve(ctx, req, stream, store) - } - srv, srvPeerID := getRequester("srv", srvHandler) - var eg 
errgroup.Group - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - cancel() - eg.Wait() - }() - eg.Go(func() error { - return srv.Run(ctx) - }) - - client, _ := getRequester("client", nil, srv) - toCall(ctx, client, srvPeerID) -} - -func fakeRequesterGetter() getRequesterFunc { - return func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) { - pid := p2p.Peer(name) - return newFakeRequester(pid, handler, peers...), pid - } -} - -func p2pRequesterGetter(t *testing.T) getRequesterFunc { - mesh, err := mocknet.FullMeshConnected(2) - require.NoError(t, err) - proto := "itest" - opts := []server.Opt{ - server.WithRequestSizeLimit(100_000_000), - server.WithTimeout(10 * time.Second), - server.WithLog(zaptest.NewLogger(t)), - } - return func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) { - if len(peers) == 0 { - return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() - } - s := server.New(mesh.Hosts()[1], proto, handler, opts...) 
- // TODO: this 'Eventually' is somewhat misplaced - require.Eventually(t, func() bool { - for _, h := range mesh.Hosts()[0:] { - if len(h.Mux().Protocols()) == 0 { - return false - } - } - return true - }, time.Second, 10*time.Millisecond) - return s, mesh.Hosts()[1].ID() - } -} - -type syncTracer struct { - dumb bool - receivedItems int - sentItems int -} - -var _ Tracer = &syncTracer{} - -func (tr *syncTracer) OnDumbSync() { - tr.dumb = true -} - -func (tr *syncTracer) OnRecent(receivedItems, sentItems int) { - tr.receivedItems += receivedItems - tr.sentItems += sentItems -} - -type fakeRecentIterator struct { - items []types.Hash32 - p int -} - -func (it *fakeRecentIterator) Clone() Iterator { - return &fakeRecentIterator{items: it.items} -} - -func (it *fakeRecentIterator) Key() (Ordered, error) { - return it.items[it.p], nil -} - -func (it *fakeRecentIterator) Next() error { - it.p = (it.p + 1) % len(it.items) - return nil -} - -var _ Iterator = &fakeRecentIterator{} - -type fakeRecentSet struct { - ItemStore - timestamps map[types.Hash32]time.Time - clock clockwork.Clock -} - -var _ ItemStore = &fakeRecentSet{} - -var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) - -func (frs *fakeRecentSet) registerAll(ctx context.Context) error { - frs.timestamps = make(map[types.Hash32]time.Time) - t := startDate - for v, err := range IterItems[types.Hash32](ctx, frs.ItemStore) { - if err != nil { - return err - } - frs.timestamps[v] = t - t = t.Add(time.Second) - } - return nil -} - -func (frs *fakeRecentSet) Add(ctx context.Context, k Ordered) error { - if err := frs.ItemStore.Add(ctx, k); err != nil { - return err - } - h := k.(types.Hash32) - frs.timestamps[h] = frs.clock.Now() - return nil -} - -func (frs *fakeRecentSet) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { - var items []types.Hash32 - for h, err := range IterItems[types.Hash32](ctx, frs.ItemStore) { - if err != nil { - return nil, 0, err - } - if 
!frs.timestamps[h].Before(since) { - items = append(items, h) - } - } - return &fakeRecentIterator{items: items}, len(items), nil -} - -func testWireSync(t *testing.T, getRequester getRequesterFunc) { - for _, tc := range []struct { - name string - cfg xorSyncTestConfig - dumb bool - opts []RangeSetReconcilerOption - advance time.Duration - sentRecent bool - receivedRecent bool - }{ - { - name: "non-dumb sync", - cfg: xorSyncTestConfig{ - maxSendRange: 1, - numTestHashes: 1000, - minNumSpecificA: 8, - maxNumSpecificA: 16, - minNumSpecificB: 8, - maxNumSpecificB: 16, - }, - dumb: false, - }, - { - name: "dumb sync", - cfg: xorSyncTestConfig{ - maxSendRange: 1, - numTestHashes: 1000, - minNumSpecificA: 400, - maxNumSpecificA: 500, - minNumSpecificB: 400, - maxNumSpecificB: 500, - }, - dumb: true, - }, - { - name: "recent sync", - cfg: xorSyncTestConfig{ - maxSendRange: 1, - numTestHashes: 1000, - minNumSpecificA: 400, - maxNumSpecificA: 500, - minNumSpecificB: 400, - maxNumSpecificB: 500, - allowReAdd: true, - }, - dumb: false, - opts: []RangeSetReconcilerOption{ - WithRecentTimeSpan(990 * time.Second), - }, - advance: 1000 * time.Second, - sentRecent: true, - receivedRecent: true, - }, - { - name: "larger sync", - cfg: xorSyncTestConfig{ - // even larger test: - // maxSendRange: 1, - // numTestHashes: 5000000, - // minNumSpecificA: 15000, - // maxNumSpecificA: 20000, - // minNumSpecificB: 15, - // maxNumSpecificB: 20, - - maxSendRange: 1, - numTestHashes: 100000, - minNumSpecificA: 4, - maxNumSpecificA: 100, - minNumSpecificB: 4, - maxNumSpecificB: 100, - }, - dumb: false, - }, - } { - t.Run(tc.name, func(t *testing.T) { - verifyXORSync(t, tc.cfg, func( - storeA, storeB ItemStore, - numSpecific int, - opts []RangeSetReconcilerOption, - ) bool { - clock := clockwork.NewFakeClockAt(startDate) - // Note that at this point, the items are already added to the sets - // and thus fakeRecentSet.Add is not invoked for them, just underlying - // set's Add method - frsA := 
&fakeRecentSet{ItemStore: storeA, clock: clock} - require.NoError(t, frsA.registerAll(context.Background())) - storeA = frsA - frsB := &fakeRecentSet{ItemStore: storeB, clock: clock} - require.NoError(t, frsB.registerAll(context.Background())) - storeB = frsB - var tr syncTracer - opts = append(opts, WithTracer(&tr), WithRangeReconcilerClock(clock)) - opts = append(opts, tc.opts...) - opts = opts[0:len(opts):len(opts)] - clock.Advance(tc.advance) - withClientServer( - storeA, getRequester, - opts, - func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { - nr := RmmeNumRead() - nw := RmmeNumWritten() - pss := NewPairwiseStoreSyncer(client, opts) - err := pss.SyncStore(ctx, srvPeerID, storeB, nil, nil) - require.NoError(t, err) - - if fr, ok := client.(*fakeRequester); ok { - t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", - numSpecific, fr.bytesSent, fr.bytesReceived) - } - t.Logf("bytes read: %d, bytes written: %d", - RmmeNumRead()-nr, RmmeNumWritten()-nw) - }) - require.Equal(t, tc.dumb, tr.dumb, "dumb sync") - require.Equal(t, tc.receivedRecent, tr.receivedItems > 0) - require.Equal(t, tc.sentRecent, tr.sentItems > 0) - return true - }) - }) - } -} - -func TestWireSync(t *testing.T) { - t.Run("fake requester", func(t *testing.T) { - testWireSync(t, fakeRequesterGetter()) - }) - t.Run("p2p", func(t *testing.T) { - testWireSync(t, p2pRequesterGetter(t)) - }) -} - -func testWireProbe(t *testing.T, getRequester getRequesterFunc) Requester { - cfg := xorSyncTestConfig{ - maxSendRange: 1, - numTestHashes: 10000, - minNumSpecificA: 130, - maxNumSpecificA: 130, - minNumSpecificB: 130, - maxNumSpecificB: 130, - } - var client Requester - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { - withClientServer( - storeA, getRequester, opts, - func(ctx context.Context, client Requester, srvPeerID p2p.Peer) { - pss := NewPairwiseStoreSyncer(client, opts) - minA, err := storeA.Min(ctx) - 
require.NoError(t, err) - kA, err := minA.Key() - require.NoError(t, err) - infoA, err := storeA.GetRangeInfo(ctx, nil, kA, kA, -1) - require.NoError(t, err) - prA, err := pss.Probe(ctx, srvPeerID, storeB, nil, nil) - require.NoError(t, err) - require.Equal(t, infoA.Fingerprint, prA.FP) - require.Equal(t, infoA.Count, prA.Count) - require.InDelta(t, 0.98, prA.Sim, 0.05, "sim") - - minA, err = storeA.Min(ctx) - require.NoError(t, err) - kA, err = minA.Key() - require.NoError(t, err) - partInfoA, err := storeA.GetRangeInfo(ctx, nil, kA, kA, infoA.Count/2) - require.NoError(t, err) - xK, err := partInfoA.Start.Key() - require.NoError(t, err) - x := xK.(types.Hash32) - yK, err := partInfoA.End.Key() - y := yK.(types.Hash32) - // partInfoA = storeA.GetRangeInfo(nil, x, y, -1) - prA, err = pss.Probe(ctx, srvPeerID, storeB, &x, &y) - require.NoError(t, err) - require.Equal(t, partInfoA.Fingerprint, prA.FP) - require.Equal(t, partInfoA.Count, prA.Count) - require.InDelta(t, 0.98, prA.Sim, 0.1, "sim") - // QQQQQ: TBD: check prA.Sim and prB.Sim values - }) - return false - }) - return client -} - -func TestWireProbe(t *testing.T) { - t.Run("fake requester", func(t *testing.T) { - testWireProbe(t, fakeRequesterGetter()) - }) - t.Run("p2p", func(t *testing.T) { - testWireProbe(t, p2pRequesterGetter(t)) - }) -} - -// TODO: test bounded sync -// TODO: test fail handler diff --git a/sync2/hashsync/interface.go b/sync2/hashsync/interface.go deleted file mode 100644 index acde320a84..0000000000 --- a/sync2/hashsync/interface.go +++ /dev/null @@ -1,102 +0,0 @@ -package hashsync - -import ( - "context" - "io" - "time" - - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/p2p/server" -) - -//go:generate mockgen -typed -package=hashsync -destination=./mocks_test.go -source=./interface.go - -// Iterator points to in item in ItemStore -type Iterator interface { - // Key returns the key corresponding to 
iterator position. It returns - // nil if the ItemStore is empty - // If the iterator is returned along with a count, the return value of Key() - // after calling Next() count times is dependent on the implementation. - Key() (Ordered, error) - // Next advances the iterator - Next() error - // Clone returns a copy of the iterator - Clone() Iterator -} - -// RangeInfo contains information about a range of items in the ItemStore as returned by -// ItemStore.GetRangeInfo. -type RangeInfo struct { - // Fingerprint of the interval - Fingerprint any - // Number of items in the interval - Count int - // An iterator pointing to the beginning of the interval or nil if count is zero. - Start Iterator - // An iterator pointing to the end of the interval or nil if count is zero. - End Iterator -} - -// SplitInfo contains information about range split in two. -type SplitInfo struct { - // 2 parts of the range - Parts [2]RangeInfo - // Middle point between the ranges - Middle Ordered -} - -// ItemStore represents the data store that can be synced against a remote peer -type ItemStore interface { - // Add adds a key to the store - Add(ctx context.Context, k Ordered) error - // GetRangeInfo returns RangeInfo for the item range in the tree. - // If count >= 0, at most count items are returned, and RangeInfo - // is returned for the corresponding subrange of the requested range. - // If both x and y is nil, the whole set of items is used. - // If only x or only y is nil, GetRangeInfo panics - GetRangeInfo(ctx context.Context, preceding Iterator, x, y Ordered, count int) (RangeInfo, error) - // SplitRange splits the range roughly after the specified count of items, - // returning RangeInfo for the first half and the second half of the range. - SplitRange(ctx context.Context, preceding Iterator, x, y Ordered, count int) (SplitInfo, error) - // Min returns the iterator pointing at the minimum element - // in the store. 
If the store is empty, it returns nil - Min(ctx context.Context) (Iterator, error) - // Copy makes a shallow copy of the ItemStore - Copy() ItemStore - // Has returns true if the specified key is present in ItemStore - Has(ctx context.Context, k Ordered) (bool, error) - // Recent returns an Iterator that yields the items added since the specified - // timestamp. Some ItemStore implementations may not have Recent implemented, in - // which case it should return an error. - Recent(ctx context.Context, since time.Time) (Iterator, int, error) -} - -type Requester interface { - Run(context.Context) error - StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error -} - -type SyncBase interface { - Count(ctx context.Context) (int, error) - Derive(p p2p.Peer) Syncer - Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) - Wait() error -} - -type Syncer interface { - Peer() p2p.Peer - Sync(ctx context.Context, x, y *types.Hash32) error - Serve(ctx context.Context, req []byte, stream io.ReadWriter) error -} - -type PairwiseSyncer interface { - Probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) - SyncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error - Serve(ctx context.Context, req []byte, stream io.ReadWriter, is ItemStore) error -} - -type syncRunner interface { - splitSync(ctx context.Context, syncPeers []p2p.Peer) error - fullSync(ctx context.Context, syncPeers []p2p.Peer) error -} diff --git a/sync2/hashsync/mocks_test.go b/sync2/hashsync/mocks_test.go deleted file mode 100644 index 6902d22b30..0000000000 --- a/sync2/hashsync/mocks_test.go +++ /dev/null @@ -1,1110 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./interface.go -// -// Generated by this command: -// -// mockgen -typed -package=hashsync -destination=./mocks_test.go -source=./interface.go -// - -// Package hashsync is a generated GoMock package. 
-package hashsync - -import ( - context "context" - io "io" - reflect "reflect" - time "time" - - types "github.com/spacemeshos/go-spacemesh/common/types" - p2p "github.com/spacemeshos/go-spacemesh/p2p" - server "github.com/spacemeshos/go-spacemesh/p2p/server" - gomock "go.uber.org/mock/gomock" -) - -// MockIterator is a mock of Iterator interface. -type MockIterator struct { - ctrl *gomock.Controller - recorder *MockIteratorMockRecorder -} - -// MockIteratorMockRecorder is the mock recorder for MockIterator. -type MockIteratorMockRecorder struct { - mock *MockIterator -} - -// NewMockIterator creates a new mock instance. -func NewMockIterator(ctrl *gomock.Controller) *MockIterator { - mock := &MockIterator{ctrl: ctrl} - mock.recorder = &MockIteratorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockIterator) EXPECT() *MockIteratorMockRecorder { - return m.recorder -} - -// Clone mocks base method. -func (m *MockIterator) Clone() Iterator { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Clone") - ret0, _ := ret[0].(Iterator) - return ret0 -} - -// Clone indicates an expected call of Clone. 
-func (mr *MockIteratorMockRecorder) Clone() *MockIteratorCloneCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clone", reflect.TypeOf((*MockIterator)(nil).Clone)) - return &MockIteratorCloneCall{Call: call} -} - -// MockIteratorCloneCall wrap *gomock.Call -type MockIteratorCloneCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockIteratorCloneCall) Return(arg0 Iterator) *MockIteratorCloneCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockIteratorCloneCall) Do(f func() Iterator) *MockIteratorCloneCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockIteratorCloneCall) DoAndReturn(f func() Iterator) *MockIteratorCloneCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Key mocks base method. -func (m *MockIterator) Key() (Ordered, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Key") - ret0, _ := ret[0].(Ordered) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Key indicates an expected call of Key. 
-func (mr *MockIteratorMockRecorder) Key() *MockIteratorKeyCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Key", reflect.TypeOf((*MockIterator)(nil).Key)) - return &MockIteratorKeyCall{Call: call} -} - -// MockIteratorKeyCall wrap *gomock.Call -type MockIteratorKeyCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockIteratorKeyCall) Return(arg0 Ordered, arg1 error) *MockIteratorKeyCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockIteratorKeyCall) Do(f func() (Ordered, error)) *MockIteratorKeyCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockIteratorKeyCall) DoAndReturn(f func() (Ordered, error)) *MockIteratorKeyCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Next mocks base method. -func (m *MockIterator) Next() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Next") - ret0, _ := ret[0].(error) - return ret0 -} - -// Next indicates an expected call of Next. -func (mr *MockIteratorMockRecorder) Next() *MockIteratorNextCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIterator)(nil).Next)) - return &MockIteratorNextCall{Call: call} -} - -// MockIteratorNextCall wrap *gomock.Call -type MockIteratorNextCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockIteratorNextCall) Return(arg0 error) *MockIteratorNextCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockIteratorNextCall) Do(f func() error) *MockIteratorNextCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockIteratorNextCall) DoAndReturn(f func() error) *MockIteratorNextCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockItemStore is a mock of ItemStore interface. 
-type MockItemStore struct { - ctrl *gomock.Controller - recorder *MockItemStoreMockRecorder -} - -// MockItemStoreMockRecorder is the mock recorder for MockItemStore. -type MockItemStoreMockRecorder struct { - mock *MockItemStore -} - -// NewMockItemStore creates a new mock instance. -func NewMockItemStore(ctrl *gomock.Controller) *MockItemStore { - mock := &MockItemStore{ctrl: ctrl} - mock.recorder = &MockItemStoreMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockItemStore) EXPECT() *MockItemStoreMockRecorder { - return m.recorder -} - -// Add mocks base method. -func (m *MockItemStore) Add(ctx context.Context, k Ordered) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", ctx, k) - ret0, _ := ret[0].(error) - return ret0 -} - -// Add indicates an expected call of Add. -func (mr *MockItemStoreMockRecorder) Add(ctx, k any) *MockItemStoreAddCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockItemStore)(nil).Add), ctx, k) - return &MockItemStoreAddCall{Call: call} -} - -// MockItemStoreAddCall wrap *gomock.Call -type MockItemStoreAddCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreAddCall) Return(arg0 error) *MockItemStoreAddCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreAddCall) Do(f func(context.Context, Ordered) error) *MockItemStoreAddCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreAddCall) DoAndReturn(f func(context.Context, Ordered) error) *MockItemStoreAddCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Copy mocks base method. -func (m *MockItemStore) Copy() ItemStore { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Copy") - ret0, _ := ret[0].(ItemStore) - return ret0 -} - -// Copy indicates an expected call of Copy. 
-func (mr *MockItemStoreMockRecorder) Copy() *MockItemStoreCopyCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", reflect.TypeOf((*MockItemStore)(nil).Copy)) - return &MockItemStoreCopyCall{Call: call} -} - -// MockItemStoreCopyCall wrap *gomock.Call -type MockItemStoreCopyCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreCopyCall) Return(arg0 ItemStore) *MockItemStoreCopyCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreCopyCall) Do(f func() ItemStore) *MockItemStoreCopyCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreCopyCall) DoAndReturn(f func() ItemStore) *MockItemStoreCopyCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetRangeInfo mocks base method. -func (m *MockItemStore) GetRangeInfo(ctx context.Context, preceding Iterator, x, y Ordered, count int) (RangeInfo, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRangeInfo", ctx, preceding, x, y, count) - ret0, _ := ret[0].(RangeInfo) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRangeInfo indicates an expected call of GetRangeInfo. 
-func (mr *MockItemStoreMockRecorder) GetRangeInfo(ctx, preceding, x, y, count any) *MockItemStoreGetRangeInfoCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeInfo", reflect.TypeOf((*MockItemStore)(nil).GetRangeInfo), ctx, preceding, x, y, count) - return &MockItemStoreGetRangeInfoCall{Call: call} -} - -// MockItemStoreGetRangeInfoCall wrap *gomock.Call -type MockItemStoreGetRangeInfoCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreGetRangeInfoCall) Return(arg0 RangeInfo, arg1 error) *MockItemStoreGetRangeInfoCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreGetRangeInfoCall) Do(f func(context.Context, Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreGetRangeInfoCall) DoAndReturn(f func(context.Context, Iterator, Ordered, Ordered, int) (RangeInfo, error)) *MockItemStoreGetRangeInfoCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Has mocks base method. -func (m *MockItemStore) Has(ctx context.Context, k Ordered) (bool, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Has", ctx, k) - ret0, _ := ret[0].(bool) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Has indicates an expected call of Has. 
-func (mr *MockItemStoreMockRecorder) Has(ctx, k any) *MockItemStoreHasCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockItemStore)(nil).Has), ctx, k) - return &MockItemStoreHasCall{Call: call} -} - -// MockItemStoreHasCall wrap *gomock.Call -type MockItemStoreHasCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreHasCall) Return(arg0 bool, arg1 error) *MockItemStoreHasCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreHasCall) Do(f func(context.Context, Ordered) (bool, error)) *MockItemStoreHasCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreHasCall) DoAndReturn(f func(context.Context, Ordered) (bool, error)) *MockItemStoreHasCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Min mocks base method. -func (m *MockItemStore) Min(ctx context.Context) (Iterator, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Min", ctx) - ret0, _ := ret[0].(Iterator) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Min indicates an expected call of Min. 
-func (mr *MockItemStoreMockRecorder) Min(ctx any) *MockItemStoreMinCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Min", reflect.TypeOf((*MockItemStore)(nil).Min), ctx) - return &MockItemStoreMinCall{Call: call} -} - -// MockItemStoreMinCall wrap *gomock.Call -type MockItemStoreMinCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreMinCall) Return(arg0 Iterator, arg1 error) *MockItemStoreMinCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreMinCall) Do(f func(context.Context) (Iterator, error)) *MockItemStoreMinCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreMinCall) DoAndReturn(f func(context.Context) (Iterator, error)) *MockItemStoreMinCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Recent mocks base method. -func (m *MockItemStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Recent", ctx, since) - ret0, _ := ret[0].(Iterator) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// Recent indicates an expected call of Recent. 
-func (mr *MockItemStoreMockRecorder) Recent(ctx, since any) *MockItemStoreRecentCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recent", reflect.TypeOf((*MockItemStore)(nil).Recent), ctx, since) - return &MockItemStoreRecentCall{Call: call} -} - -// MockItemStoreRecentCall wrap *gomock.Call -type MockItemStoreRecentCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreRecentCall) Return(arg0 Iterator, arg1 int, arg2 error) *MockItemStoreRecentCall { - c.Call = c.Call.Return(arg0, arg1, arg2) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreRecentCall) Do(f func(context.Context, time.Time) (Iterator, int, error)) *MockItemStoreRecentCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreRecentCall) DoAndReturn(f func(context.Context, time.Time) (Iterator, int, error)) *MockItemStoreRecentCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// SplitRange mocks base method. -func (m *MockItemStore) SplitRange(ctx context.Context, preceding Iterator, x, y Ordered, count int) (SplitInfo, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SplitRange", ctx, preceding, x, y, count) - ret0, _ := ret[0].(SplitInfo) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SplitRange indicates an expected call of SplitRange. 
-func (mr *MockItemStoreMockRecorder) SplitRange(ctx, preceding, x, y, count any) *MockItemStoreSplitRangeCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SplitRange", reflect.TypeOf((*MockItemStore)(nil).SplitRange), ctx, preceding, x, y, count) - return &MockItemStoreSplitRangeCall{Call: call} -} - -// MockItemStoreSplitRangeCall wrap *gomock.Call -type MockItemStoreSplitRangeCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockItemStoreSplitRangeCall) Return(arg0 SplitInfo, arg1 error) *MockItemStoreSplitRangeCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockItemStoreSplitRangeCall) Do(f func(context.Context, Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockItemStoreSplitRangeCall) DoAndReturn(f func(context.Context, Iterator, Ordered, Ordered, int) (SplitInfo, error)) *MockItemStoreSplitRangeCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockRequester is a mock of Requester interface. -type MockRequester struct { - ctrl *gomock.Controller - recorder *MockRequesterMockRecorder -} - -// MockRequesterMockRecorder is the mock recorder for MockRequester. -type MockRequesterMockRecorder struct { - mock *MockRequester -} - -// NewMockRequester creates a new mock instance. -func NewMockRequester(ctrl *gomock.Controller) *MockRequester { - mock := &MockRequester{ctrl: ctrl} - mock.recorder = &MockRequesterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockRequester) EXPECT() *MockRequesterMockRecorder { - return m.recorder -} - -// Run mocks base method. 
-func (m *MockRequester) Run(arg0 context.Context) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Run", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Run indicates an expected call of Run. -func (mr *MockRequesterMockRecorder) Run(arg0 any) *MockRequesterRunCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRequester)(nil).Run), arg0) - return &MockRequesterRunCall{Call: call} -} - -// MockRequesterRunCall wrap *gomock.Call -type MockRequesterRunCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockRequesterRunCall) Return(arg0 error) *MockRequesterRunCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockRequesterRunCall) Do(f func(context.Context) error) *MockRequesterRunCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRequesterRunCall) DoAndReturn(f func(context.Context) error) *MockRequesterRunCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// StreamRequest mocks base method. -func (m *MockRequester) StreamRequest(arg0 context.Context, arg1 p2p.Peer, arg2 []byte, arg3 server.StreamRequestCallback, arg4 ...string) error { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1, arg2, arg3} - for _, a := range arg4 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "StreamRequest", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// StreamRequest indicates an expected call of StreamRequest. -func (mr *MockRequesterMockRecorder) StreamRequest(arg0, arg1, arg2, arg3 any, arg4 ...any) *MockRequesterStreamRequestCall { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamRequest", reflect.TypeOf((*MockRequester)(nil).StreamRequest), varargs...) 
- return &MockRequesterStreamRequestCall{Call: call} -} - -// MockRequesterStreamRequestCall wrap *gomock.Call -type MockRequesterStreamRequestCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockRequesterStreamRequestCall) Return(arg0 error) *MockRequesterStreamRequestCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockRequesterStreamRequestCall) Do(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockRequesterStreamRequestCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockRequesterStreamRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockRequesterStreamRequestCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockSyncBase is a mock of SyncBase interface. -type MockSyncBase struct { - ctrl *gomock.Controller - recorder *MockSyncBaseMockRecorder -} - -// MockSyncBaseMockRecorder is the mock recorder for MockSyncBase. -type MockSyncBaseMockRecorder struct { - mock *MockSyncBase -} - -// NewMockSyncBase creates a new mock instance. -func NewMockSyncBase(ctrl *gomock.Controller) *MockSyncBase { - mock := &MockSyncBase{ctrl: ctrl} - mock.recorder = &MockSyncBaseMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSyncBase) EXPECT() *MockSyncBaseMockRecorder { - return m.recorder -} - -// Count mocks base method. -func (m *MockSyncBase) Count(ctx context.Context) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Count", ctx) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Count indicates an expected call of Count. 
-func (mr *MockSyncBaseMockRecorder) Count(ctx any) *MockSyncBaseCountCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockSyncBase)(nil).Count), ctx) - return &MockSyncBaseCountCall{Call: call} -} - -// MockSyncBaseCountCall wrap *gomock.Call -type MockSyncBaseCountCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncBaseCountCall) Return(arg0 int, arg1 error) *MockSyncBaseCountCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncBaseCountCall) Do(f func(context.Context) (int, error)) *MockSyncBaseCountCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncBaseCountCall) DoAndReturn(f func(context.Context) (int, error)) *MockSyncBaseCountCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Derive mocks base method. -func (m *MockSyncBase) Derive(p p2p.Peer) Syncer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Derive", p) - ret0, _ := ret[0].(Syncer) - return ret0 -} - -// Derive indicates an expected call of Derive. 
-func (mr *MockSyncBaseMockRecorder) Derive(p any) *MockSyncBaseDeriveCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Derive", reflect.TypeOf((*MockSyncBase)(nil).Derive), p) - return &MockSyncBaseDeriveCall{Call: call} -} - -// MockSyncBaseDeriveCall wrap *gomock.Call -type MockSyncBaseDeriveCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncBaseDeriveCall) Return(arg0 Syncer) *MockSyncBaseDeriveCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncBaseDeriveCall) Do(f func(p2p.Peer) Syncer) *MockSyncBaseDeriveCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncBaseDeriveCall) DoAndReturn(f func(p2p.Peer) Syncer) *MockSyncBaseDeriveCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Probe mocks base method. -func (m *MockSyncBase) Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Probe", ctx, p) - ret0, _ := ret[0].(ProbeResult) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Probe indicates an expected call of Probe. 
-func (mr *MockSyncBaseMockRecorder) Probe(ctx, p any) *MockSyncBaseProbeCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockSyncBase)(nil).Probe), ctx, p) - return &MockSyncBaseProbeCall{Call: call} -} - -// MockSyncBaseProbeCall wrap *gomock.Call -type MockSyncBaseProbeCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncBaseProbeCall) Return(arg0 ProbeResult, arg1 error) *MockSyncBaseProbeCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncBaseProbeCall) Do(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MockSyncBaseProbeCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncBaseProbeCall) DoAndReturn(f func(context.Context, p2p.Peer) (ProbeResult, error)) *MockSyncBaseProbeCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Wait mocks base method. -func (m *MockSyncBase) Wait() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Wait") - ret0, _ := ret[0].(error) - return ret0 -} - -// Wait indicates an expected call of Wait. 
-func (mr *MockSyncBaseMockRecorder) Wait() *MockSyncBaseWaitCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockSyncBase)(nil).Wait)) - return &MockSyncBaseWaitCall{Call: call} -} - -// MockSyncBaseWaitCall wrap *gomock.Call -type MockSyncBaseWaitCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncBaseWaitCall) Return(arg0 error) *MockSyncBaseWaitCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncBaseWaitCall) Do(f func() error) *MockSyncBaseWaitCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncBaseWaitCall) DoAndReturn(f func() error) *MockSyncBaseWaitCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockSyncer is a mock of Syncer interface. -type MockSyncer struct { - ctrl *gomock.Controller - recorder *MockSyncerMockRecorder -} - -// MockSyncerMockRecorder is the mock recorder for MockSyncer. -type MockSyncerMockRecorder struct { - mock *MockSyncer -} - -// NewMockSyncer creates a new mock instance. -func NewMockSyncer(ctrl *gomock.Controller) *MockSyncer { - mock := &MockSyncer{ctrl: ctrl} - mock.recorder = &MockSyncerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSyncer) EXPECT() *MockSyncerMockRecorder { - return m.recorder -} - -// Peer mocks base method. -func (m *MockSyncer) Peer() p2p.Peer { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Peer") - ret0, _ := ret[0].(p2p.Peer) - return ret0 -} - -// Peer indicates an expected call of Peer. 
-func (mr *MockSyncerMockRecorder) Peer() *MockSyncerPeerCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Peer", reflect.TypeOf((*MockSyncer)(nil).Peer)) - return &MockSyncerPeerCall{Call: call} -} - -// MockSyncerPeerCall wrap *gomock.Call -type MockSyncerPeerCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncerPeerCall) Return(arg0 p2p.Peer) *MockSyncerPeerCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncerPeerCall) Do(f func() p2p.Peer) *MockSyncerPeerCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncerPeerCall) DoAndReturn(f func() p2p.Peer) *MockSyncerPeerCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Serve mocks base method. -func (m *MockSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Serve", ctx, req, stream) - ret0, _ := ret[0].(error) - return ret0 -} - -// Serve indicates an expected call of Serve. 
-func (mr *MockSyncerMockRecorder) Serve(ctx, req, stream any) *MockSyncerServeCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockSyncer)(nil).Serve), ctx, req, stream) - return &MockSyncerServeCall{Call: call} -} - -// MockSyncerServeCall wrap *gomock.Call -type MockSyncerServeCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncerServeCall) Return(arg0 error) *MockSyncerServeCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncerServeCall) Do(f func(context.Context, []byte, io.ReadWriter) error) *MockSyncerServeCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncerServeCall) DoAndReturn(f func(context.Context, []byte, io.ReadWriter) error) *MockSyncerServeCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Sync mocks base method. -func (m *MockSyncer) Sync(ctx context.Context, x, y *types.Hash32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Sync", ctx, x, y) - ret0, _ := ret[0].(error) - return ret0 -} - -// Sync indicates an expected call of Sync. 
-func (mr *MockSyncerMockRecorder) Sync(ctx, x, y any) *MockSyncerSyncCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockSyncer)(nil).Sync), ctx, x, y) - return &MockSyncerSyncCall{Call: call} -} - -// MockSyncerSyncCall wrap *gomock.Call -type MockSyncerSyncCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockSyncerSyncCall) Return(arg0 error) *MockSyncerSyncCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockSyncerSyncCall) Do(f func(context.Context, *types.Hash32, *types.Hash32) error) *MockSyncerSyncCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockSyncerSyncCall) DoAndReturn(f func(context.Context, *types.Hash32, *types.Hash32) error) *MockSyncerSyncCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockPairwiseSyncer is a mock of PairwiseSyncer interface. -type MockPairwiseSyncer struct { - ctrl *gomock.Controller - recorder *MockPairwiseSyncerMockRecorder -} - -// MockPairwiseSyncerMockRecorder is the mock recorder for MockPairwiseSyncer. -type MockPairwiseSyncerMockRecorder struct { - mock *MockPairwiseSyncer -} - -// NewMockPairwiseSyncer creates a new mock instance. -func NewMockPairwiseSyncer(ctrl *gomock.Controller) *MockPairwiseSyncer { - mock := &MockPairwiseSyncer{ctrl: ctrl} - mock.recorder = &MockPairwiseSyncerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockPairwiseSyncer) EXPECT() *MockPairwiseSyncerMockRecorder { - return m.recorder -} - -// Probe mocks base method. 
-func (m *MockPairwiseSyncer) Probe(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) (ProbeResult, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Probe", ctx, peer, is, x, y) - ret0, _ := ret[0].(ProbeResult) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Probe indicates an expected call of Probe. -func (mr *MockPairwiseSyncerMockRecorder) Probe(ctx, peer, is, x, y any) *MockPairwiseSyncerProbeCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockPairwiseSyncer)(nil).Probe), ctx, peer, is, x, y) - return &MockPairwiseSyncerProbeCall{Call: call} -} - -// MockPairwiseSyncerProbeCall wrap *gomock.Call -type MockPairwiseSyncerProbeCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPairwiseSyncerProbeCall) Return(arg0 ProbeResult, arg1 error) *MockPairwiseSyncerProbeCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPairwiseSyncerProbeCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockPairwiseSyncerProbeCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPairwiseSyncerProbeCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) (ProbeResult, error)) *MockPairwiseSyncerProbeCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Serve mocks base method. -func (m *MockPairwiseSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter, is ItemStore) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Serve", ctx, req, stream, is) - ret0, _ := ret[0].(error) - return ret0 -} - -// Serve indicates an expected call of Serve. 
-func (mr *MockPairwiseSyncerMockRecorder) Serve(ctx, req, stream, is any) *MockPairwiseSyncerServeCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockPairwiseSyncer)(nil).Serve), ctx, req, stream, is) - return &MockPairwiseSyncerServeCall{Call: call} -} - -// MockPairwiseSyncerServeCall wrap *gomock.Call -type MockPairwiseSyncerServeCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPairwiseSyncerServeCall) Return(arg0 error) *MockPairwiseSyncerServeCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPairwiseSyncerServeCall) Do(f func(context.Context, []byte, io.ReadWriter, ItemStore) error) *MockPairwiseSyncerServeCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPairwiseSyncerServeCall) DoAndReturn(f func(context.Context, []byte, io.ReadWriter, ItemStore) error) *MockPairwiseSyncerServeCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// SyncStore mocks base method. -func (m *MockPairwiseSyncer) SyncStore(ctx context.Context, peer p2p.Peer, is ItemStore, x, y *types.Hash32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SyncStore", ctx, peer, is, x, y) - ret0, _ := ret[0].(error) - return ret0 -} - -// SyncStore indicates an expected call of SyncStore. 
-func (mr *MockPairwiseSyncerMockRecorder) SyncStore(ctx, peer, is, x, y any) *MockPairwiseSyncerSyncStoreCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncStore", reflect.TypeOf((*MockPairwiseSyncer)(nil).SyncStore), ctx, peer, is, x, y) - return &MockPairwiseSyncerSyncStoreCall{Call: call} -} - -// MockPairwiseSyncerSyncStoreCall wrap *gomock.Call -type MockPairwiseSyncerSyncStoreCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockPairwiseSyncerSyncStoreCall) Return(arg0 error) *MockPairwiseSyncerSyncStoreCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockPairwiseSyncerSyncStoreCall) Do(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockPairwiseSyncerSyncStoreCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPairwiseSyncerSyncStoreCall) DoAndReturn(f func(context.Context, p2p.Peer, ItemStore, *types.Hash32, *types.Hash32) error) *MockPairwiseSyncerSyncStoreCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MocksyncRunner is a mock of syncRunner interface. -type MocksyncRunner struct { - ctrl *gomock.Controller - recorder *MocksyncRunnerMockRecorder -} - -// MocksyncRunnerMockRecorder is the mock recorder for MocksyncRunner. -type MocksyncRunnerMockRecorder struct { - mock *MocksyncRunner -} - -// NewMocksyncRunner creates a new mock instance. -func NewMocksyncRunner(ctrl *gomock.Controller) *MocksyncRunner { - mock := &MocksyncRunner{ctrl: ctrl} - mock.recorder = &MocksyncRunnerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MocksyncRunner) EXPECT() *MocksyncRunnerMockRecorder { - return m.recorder -} - -// fullSync mocks base method. 
-func (m *MocksyncRunner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "fullSync", ctx, syncPeers) - ret0, _ := ret[0].(error) - return ret0 -} - -// fullSync indicates an expected call of fullSync. -func (mr *MocksyncRunnerMockRecorder) fullSync(ctx, syncPeers any) *MocksyncRunnerfullSyncCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "fullSync", reflect.TypeOf((*MocksyncRunner)(nil).fullSync), ctx, syncPeers) - return &MocksyncRunnerfullSyncCall{Call: call} -} - -// MocksyncRunnerfullSyncCall wrap *gomock.Call -type MocksyncRunnerfullSyncCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MocksyncRunnerfullSyncCall) Return(arg0 error) *MocksyncRunnerfullSyncCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MocksyncRunnerfullSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncRunnerfullSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// splitSync mocks base method. -func (m *MocksyncRunner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "splitSync", ctx, syncPeers) - ret0, _ := ret[0].(error) - return ret0 -} - -// splitSync indicates an expected call of splitSync. 
-func (mr *MocksyncRunnerMockRecorder) splitSync(ctx, syncPeers any) *MocksyncRunnersplitSyncCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "splitSync", reflect.TypeOf((*MocksyncRunner)(nil).splitSync), ctx, syncPeers) - return &MocksyncRunnersplitSyncCall{Call: call} -} - -// MocksyncRunnersplitSyncCall wrap *gomock.Call -type MocksyncRunnersplitSyncCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MocksyncRunnersplitSyncCall) Return(arg0 error) *MocksyncRunnersplitSyncCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MocksyncRunnersplitSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MocksyncRunnersplitSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/sync2/hashsync/monoid.go b/sync2/hashsync/monoid.go deleted file mode 100644 index 0e4672fcb6..0000000000 --- a/sync2/hashsync/monoid.go +++ /dev/null @@ -1,61 +0,0 @@ -package hashsync - -type Monoid interface { - Identity() any - Op(a, b any) any - Fingerprint(v any) any -} - -type CountingMonoid struct{} - -var _ Monoid = CountingMonoid{} - -func (m CountingMonoid) Identity() any { return 0 } -func (m CountingMonoid) Op(a, b any) any { return a.(int) + b.(int) } -func (m CountingMonoid) Fingerprint(v any) any { return 1 } - -type combinedMonoid struct { - m1, m2 Monoid -} - -func CombineMonoids(m1, m2 Monoid) Monoid { - return combinedMonoid{m1: m1, m2: m2} -} - -type CombinedFingerprint struct { - First any - Second any -} - -func (m combinedMonoid) Identity() any { - return CombinedFingerprint{ - First: m.m1.Identity(), - Second: m.m2.Identity(), - } -} - -func (m combinedMonoid) Op(a, b any) any { - ac := a.(CombinedFingerprint) - bc := 
b.(CombinedFingerprint) - return CombinedFingerprint{ - First: m.m1.Op(ac.First, bc.First), - Second: m.m2.Op(ac.Second, bc.Second), - } -} - -func (m combinedMonoid) Fingerprint(v any) any { - return CombinedFingerprint{ - First: m.m1.Fingerprint(v), - Second: m.m2.Fingerprint(v), - } -} - -func CombinedFirst[T any](fp any) T { - cfp := fp.(CombinedFingerprint) - return cfp.First.(T) -} - -func CombinedSecond[T any](fp any) T { - cfp := fp.(CombinedFingerprint) - return cfp.Second.(T) -} diff --git a/sync2/hashsync/rangesync_test.go b/sync2/hashsync/rangesync_test.go deleted file mode 100644 index bd50bdf761..0000000000 --- a/sync2/hashsync/rangesync_test.go +++ /dev/null @@ -1,985 +0,0 @@ -package hashsync - -import ( - "context" - "errors" - "fmt" - "math/rand" - "slices" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" -) - -type rangeMessage struct { - mtype MessageType - x, y Ordered - fp any - count int - keys []Ordered - since time.Time -} - -var _ SyncMessage = rangeMessage{} - -func (m rangeMessage) Type() MessageType { return m.mtype } -func (m rangeMessage) X() Ordered { return m.x } -func (m rangeMessage) Y() Ordered { return m.y } -func (m rangeMessage) Fingerprint() any { return m.fp } -func (m rangeMessage) Count() int { return m.count } -func (m rangeMessage) Keys() []Ordered { return m.keys } -func (m rangeMessage) Since() time.Time { return m.since } - -func (m rangeMessage) String() string { - return SyncMessageToString(m) -} - -type fakeConduit struct { - t *testing.T - msgs []rangeMessage - resp []rangeMessage -} - -var _ Conduit = &fakeConduit{} - -func (fc *fakeConduit) gotoResponse() { - fc.msgs = fc.resp - fc.resp = nil -} - -func (fc *fakeConduit) numItems() int { - n := 0 - for _, m := range fc.msgs { - n += len(m.Keys()) - } - return n -} - -func (fc *fakeConduit) NextMessage() (SyncMessage, error) { - if len(fc.msgs) != 0 { - m := fc.msgs[0] - 
fc.msgs = fc.msgs[1:] - return m, nil - } - - return nil, nil -} - -func (fc *fakeConduit) sendMsg(msg rangeMessage) { - fc.resp = append(fc.resp, msg) -} - -func (fc *fakeConduit) SendFingerprint(x, y Ordered, fingerprint any, count int) error { - require.NotNil(fc.t, x) - require.NotNil(fc.t, y) - require.NotZero(fc.t, count) - require.NotNil(fc.t, fingerprint) - fc.sendMsg(rangeMessage{ - mtype: MessageTypeFingerprint, - x: x, - y: y, - fp: fingerprint, - count: count, - }) - return nil -} - -func (fc *fakeConduit) SendEmptySet() error { - fc.sendMsg(rangeMessage{mtype: MessageTypeEmptySet}) - return nil -} - -func (fc *fakeConduit) SendEmptyRange(x, y Ordered) error { - require.NotNil(fc.t, x) - require.NotNil(fc.t, y) - fc.sendMsg(rangeMessage{ - mtype: MessageTypeEmptyRange, - x: x, - y: y, - }) - return nil -} - -func (fc *fakeConduit) SendRangeContents(x, y Ordered, count int) error { - require.NotNil(fc.t, x) - require.NotNil(fc.t, y) - fc.sendMsg(rangeMessage{ - mtype: MessageTypeRangeContents, - x: x, - y: y, - count: count, - }) - return nil -} - -func (fc *fakeConduit) SendChunk(items []Ordered) error { - require.NotEmpty(fc.t, items) - fc.sendMsg(rangeMessage{ - mtype: MessageTypeItemBatch, - keys: items, - }) - return nil -} - -func (fc *fakeConduit) SendEndRound() error { - fc.sendMsg(rangeMessage{mtype: MessageTypeEndRound}) - return nil -} - -func (fc *fakeConduit) SendDone() error { - fc.sendMsg(rangeMessage{mtype: MessageTypeDone}) - return nil -} - -func (fc *fakeConduit) SendProbe(x, y Ordered, fingerprint any, sampleSize int) error { - fc.sendMsg(rangeMessage{ - mtype: MessageTypeProbe, - x: x, - y: y, - fp: fingerprint, - count: sampleSize, - }) - return nil -} - -func (fc *fakeConduit) SendSample(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error { - msg := rangeMessage{ - mtype: MessageTypeSample, - x: x, - y: y, - fp: fingerprint, - count: count, - keys: make([]Ordered, sampleSize), - } - for n := 0; n < sampleSize; 
n++ { - k, err := it.Key() - require.NoError(fc.t, err) - require.NotNil(fc.t, k) - msg.keys[n] = k - if err := it.Next(); err != nil { - return err - } - } - fc.sendMsg(msg) - return nil -} - -func (fc *fakeConduit) SendRecent(since time.Time) error { - fc.sendMsg(rangeMessage{ - mtype: MessageTypeRecent, - since: since, - }) - return nil -} - -func (fc *fakeConduit) ShortenKey(k Ordered) Ordered { - return k -} - -type dumbStoreIterator struct { - ds *dumbStore - n int -} - -var _ Iterator = &dumbStoreIterator{} - -func (it *dumbStoreIterator) Equal(other Iterator) bool { - o := other.(*dumbStoreIterator) - if it.ds != o.ds { - panic("comparing iterators from different dumbStores") - } - return it.n == o.n -} - -func (it *dumbStoreIterator) Key() (Ordered, error) { - return it.ds.keys[it.n], nil -} - -func (it *dumbStoreIterator) Next() error { - if len(it.ds.keys) != 0 { - it.n = (it.n + 1) % len(it.ds.keys) - } - return nil -} - -func (it *dumbStoreIterator) Clone() Iterator { - return &dumbStoreIterator{ - ds: it.ds, - n: it.n, - } -} - -type dumbStore struct { - keys []sampleID -} - -var _ ItemStore = &dumbStore{} - -func (ds *dumbStore) Add(ctx context.Context, k Ordered) error { - id := k.(sampleID) - if len(ds.keys) == 0 { - ds.keys = []sampleID{id} - return nil - } - p := slices.IndexFunc(ds.keys, func(other sampleID) bool { - return other >= id - }) - switch { - case p < 0: - ds.keys = append(ds.keys, id) - case id == ds.keys[p]: - // already present - default: - ds.keys = slices.Insert(ds.keys, p, id) - } - - return nil -} - -func (ds *dumbStore) iter(n int) Iterator { - if n == -1 || n == len(ds.keys) { - return nil - } - return &dumbStoreIterator{ds: ds, n: n} -} - -func (ds *dumbStore) last() sampleID { - if len(ds.keys) == 0 { - panic("can't get the last element: zero items") - } - return ds.keys[len(ds.keys)-1] -} - -func (ds *dumbStore) iterFor(s sampleID) Iterator { - n := slices.Index(ds.keys, s) - if n == -1 { - panic("item not found: " + s) - 
} - return ds.iter(n) -} - -func (ds *dumbStore) GetRangeInfo( - ctx context.Context, - preceding Iterator, - x, y Ordered, - count int, -) (RangeInfo, error) { - if x == nil && y == nil { - it, err := ds.Min(ctx) - if err != nil { - return RangeInfo{}, err - } - if it == nil { - return RangeInfo{ - Fingerprint: "", - }, nil - } else { - x, err = it.Key() - if err != nil { - return RangeInfo{}, err - } - y = x - } - } else if x == nil || y == nil { - panic("BUG: bad X or Y") - } - all := "" - for _, k := range ds.keys { - all += string(k) - } - vx := x.(sampleID) - vy := y.(sampleID) - if preceding != nil { - k, err := preceding.Key() - if err != nil { - return RangeInfo{}, err - } - if k.Compare(x) > 0 { - panic("preceding info after x") - } - } - fp, startStr, endStr := naiveRange(all, string(vx), string(vy), count) - r := RangeInfo{ - Fingerprint: fp, - Count: len(fp), - } - if all != "" { - if startStr == "" || endStr == "" { - panic("empty startStr/endStr from naiveRange") - } - r.Start = ds.iterFor(sampleID(startStr)) - r.End = ds.iterFor(sampleID(endStr)) - } - return r, nil -} - -func (ds *dumbStore) SplitRange( - ctx context.Context, - preceding Iterator, - x, y Ordered, - count int, -) (SplitInfo, error) { - if count <= 0 { - panic("BUG: bad split count") - } - part0, err := ds.GetRangeInfo(ctx, preceding, x, y, count) - if err != nil { - return SplitInfo{}, err - } - if part0.Count == 0 { - return SplitInfo{}, errors.New("can't split empty range") - } - middle, err := part0.End.Key() - if err != nil { - return SplitInfo{}, err - } - part1, err := ds.GetRangeInfo(ctx, part0.End.Clone(), middle, y, -1) - if err != nil { - return SplitInfo{}, err - } - return SplitInfo{ - Parts: [2]RangeInfo{part0, part1}, - Middle: middle, - }, nil -} - -func (ds *dumbStore) Min(ctx context.Context) (Iterator, error) { - if len(ds.keys) == 0 { - return nil, nil - } - return &dumbStoreIterator{ - ds: ds, - n: 0, - }, nil -} - -func (ds *dumbStore) Copy() ItemStore { - 
return &dumbStore{keys: slices.Clone(ds.keys)} -} - -func (ds *dumbStore) Has(ctx context.Context, k Ordered) (bool, error) { - for _, cur := range ds.keys { - if k.Compare(cur) == 0 { - return true, nil - } - } - return false, nil -} - -func (ds *dumbStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { - return nil, 0, nil -} - -type verifiedStoreIterator struct { - t *testing.T - knownGood Iterator - it Iterator -} - -var _ Iterator = &verifiedStoreIterator{} - -func (it verifiedStoreIterator) Key() (Ordered, error) { - k1, err := it.knownGood.Key() - if err != nil { - return nil, err - } - k2, err := it.it.Key() - if err == nil { - assert.Equal(it.t, k1, k2, "keys") - } - return k2, nil -} - -func (it verifiedStoreIterator) Next() error { - err1 := it.knownGood.Next() - err2 := it.it.Next() - switch { - case err1 == nil && err2 == nil: - k1, err := it.knownGood.Key() - if err != nil { - return err - } - k2, err := it.it.Key() - if err != nil { - return err - } - assert.Equal(it.t, k1, k2, "keys for Next()") - case err1 != nil && err2 != nil: - return err2 - default: - assert.Fail(it.t, "iterator error mismatch") - } - return nil -} - -func (it verifiedStoreIterator) Clone() Iterator { - return verifiedStoreIterator{ - t: it.t, - knownGood: it.knownGood.Clone(), - it: it.it.Clone(), - } -} - -type verifiedStore struct { - t *testing.T - knownGood ItemStore - store ItemStore - disableReAdd bool - added map[sampleID]struct{} -} - -var _ ItemStore = &verifiedStore{} - -func disableReAdd(s ItemStore) { - if vs, ok := s.(*verifiedStore); ok { - vs.disableReAdd = true - } -} - -func (vs *verifiedStore) Add(ctx context.Context, k Ordered) error { - if vs.disableReAdd { - _, found := vs.added[k.(sampleID)] - require.False(vs.t, found, "hash sent twice: %v", k) - if vs.added == nil { - vs.added = make(map[sampleID]struct{}) - } - vs.added[k.(sampleID)] = struct{}{} - } - if err := vs.knownGood.Add(ctx, k); err != nil { - return fmt.Errorf("add to 
knownGood: %w", err) - } - if err := vs.store.Add(ctx, k); err != nil { - return fmt.Errorf("add to store: %w", err) - } - return nil -} - -func (vs *verifiedStore) verifySameRangeInfo(ri1, ri2 RangeInfo) RangeInfo { - require.Equal(vs.t, ri1.Fingerprint, ri2.Fingerprint, "range info fingerprint") - require.Equal(vs.t, ri1.Count, ri2.Count, "range info count") - ri := RangeInfo{ - Fingerprint: ri2.Fingerprint, - Count: ri2.Count, - } - if ri1.Start == nil { - require.Nil(vs.t, ri2.Start, "range info start") - require.Nil(vs.t, ri1.End, "range info end (known good)") - require.Nil(vs.t, ri2.End, "range info end") - } else { - require.NotNil(vs.t, ri2.Start, "range info start") - k1, err := ri1.Start.Key() - require.NoError(vs.t, err) - k2, err := ri2.Start.Key() - require.NoError(vs.t, err) - require.Equal(vs.t, k1, k2, "range info start key") - require.NotNil(vs.t, ri1.End, "range info end (known good)") - require.NotNil(vs.t, ri2.End, "range info end") - ri.Start = verifiedStoreIterator{ - t: vs.t, - knownGood: ri1.Start, - it: ri2.Start, - } - } - if ri1.End == nil { - require.Nil(vs.t, ri2.End, "range info end") - } else { - require.NotNil(vs.t, ri2.End, "range info end") - k1, err := ri1.Start.Key() - require.NoError(vs.t, err) - k2, err := ri2.Start.Key() - require.NoError(vs.t, err) - require.Equal(vs.t, k1, k2, "range info end key") - ri.End = verifiedStoreIterator{ - t: vs.t, - knownGood: ri1.End, - it: ri2.End, - } - } - return ri -} - -func (vs *verifiedStore) GetRangeInfo( - ctx context.Context, - preceding Iterator, - x, y Ordered, - count int, -) (RangeInfo, error) { - var ( - ri1, ri2 RangeInfo - err error - ) - if preceding != nil { - p := preceding.(verifiedStoreIterator) - ri1, err = vs.knownGood.GetRangeInfo(ctx, p.knownGood, x, y, count) - require.NoError(vs.t, err) - ri2, err = vs.store.GetRangeInfo(ctx, p.it, x, y, count) - require.NoError(vs.t, err) - } else { - ri1, err = vs.knownGood.GetRangeInfo(ctx, nil, x, y, count) - 
require.NoError(vs.t, err) - ri2, err = vs.store.GetRangeInfo(ctx, nil, x, y, count) - require.NoError(vs.t, err) - } - // QQQQQ: TODO: if count >= 0 and start+end != nil, do more calls to GetRangeInfo using resulting - // end iterator key to make sure the range is correct - return vs.verifySameRangeInfo(ri1, ri2), nil -} - -func (vs *verifiedStore) SplitRange( - ctx context.Context, - preceding Iterator, - x, y Ordered, - count int, -) (SplitInfo, error) { - var ( - si1, si2 SplitInfo - err error - ) - if preceding != nil { - p := preceding.(verifiedStoreIterator) - si1, err = vs.knownGood.SplitRange(ctx, p.knownGood, x, y, count) - require.NoError(vs.t, err) - si2, err = vs.store.SplitRange(ctx, p.it, x, y, count) - require.NoError(vs.t, err) - } else { - si1, err = vs.knownGood.SplitRange(ctx, nil, x, y, count) - require.NoError(vs.t, err) - si2, err = vs.store.SplitRange(ctx, nil, x, y, count) - require.NoError(vs.t, err) - } - require.Equal(vs.t, si1.Middle, si2.Middle, "split middle") - return SplitInfo{ - Parts: [2]RangeInfo{ - vs.verifySameRangeInfo(si1.Parts[0], si2.Parts[0]), - vs.verifySameRangeInfo(si1.Parts[1], si2.Parts[1]), - }, - Middle: si1.Middle, - }, nil -} - -func (vs *verifiedStore) Min(ctx context.Context) (Iterator, error) { - m1, err := vs.knownGood.Min(ctx) - require.NoError(vs.t, err) - m2, err := vs.store.Min(ctx) - require.NoError(vs.t, err) - if m1 == nil { - require.Nil(vs.t, m2, "Min") - return nil, nil - } else { - require.NotNil(vs.t, m2, "Min") - k1, err := m1.Key() - require.NoError(vs.t, err) - k2, err := m2.Key() - require.NoError(vs.t, err) - require.Equal(vs.t, k1, k2, "Min key") - } - return verifiedStoreIterator{ - t: vs.t, - knownGood: m1, - it: m2, - }, nil -} - -func (vs *verifiedStore) Copy() ItemStore { - return &verifiedStore{ - t: vs.t, - knownGood: vs.knownGood.Copy(), - store: vs.store.Copy(), - disableReAdd: vs.disableReAdd, - } -} - -func (vs *verifiedStore) Has(ctx context.Context, k Ordered) (bool, error) { - 
h1, err := vs.knownGood.Has(ctx, k) - require.NoError(vs.t, err) - h2, err := vs.store.Has(ctx, k) - require.NoError(vs.t, err) - require.Equal(vs.t, h1, h2) - return h2, nil -} - -func (vs *verifiedStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { - return nil, 0, nil -} - -type storeFactory func(t *testing.T) ItemStore - -func makeDumbStore(t *testing.T) ItemStore { - return &dumbStore{} -} - -func makeSyncTreeStore(t *testing.T) ItemStore { - return NewSyncTreeStore(sampleMonoid{}) -} - -func makeVerifiedSyncTreeStore(t *testing.T) ItemStore { - return &verifiedStore{ - t: t, - knownGood: makeDumbStore(t), - store: makeSyncTreeStore(t), - } -} - -func makeStore(t *testing.T, f storeFactory, items string) ItemStore { - s := f(t) - for _, c := range items { - require.NoError(t, s.Add(context.Background(), sampleID(c))) - } - return s -} - -func storeItemStr(is ItemStore) string { - ids, err := CollectStoreItems[sampleID](context.Background(), is) - if err != nil { - panic("store error") - } - var r strings.Builder - for _, id := range ids { - r.WriteString(string(id)) - } - return r.String() -} - -var testStores = []struct { - name string - factory storeFactory -}{ - { - name: "dumb store", - factory: makeDumbStore, - }, - { - name: "monoid tree store", - factory: makeSyncTreeStore, - }, - { - name: "verified monoid tree store", - factory: makeVerifiedSyncTreeStore, - }, -} - -func forTestStores(t *testing.T, testFunc func(t *testing.T, factory storeFactory)) { - for _, s := range testStores { - t.Run(s.name, func(t *testing.T) { - testFunc(t, s.factory) - }) - } -} - -// QQQQQ: rm -func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...any) { - t.Logf(fmt, args...) 
- for _, m := range msgs { - t.Logf(" %s", m) - } -} - -func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { - fc := &fakeConduit{t: t} - require.NoError(t, syncA.Initiate(context.Background(), fc)) - return doRunSync(fc, syncA, syncB, maxRounds) -} - -func runBoundedSync(t *testing.T, syncA, syncB *RangeSetReconciler, x, y Ordered, maxRounds int) (nRounds, nMsg, nItems int) { - fc := &fakeConduit{t: t} - require.NoError(t, syncA.InitiateBounded(context.Background(), fc, x, y)) - return doRunSync(fc, syncA, syncB, maxRounds) -} - -func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { - var i int - aDone, bDone := false, false - // dumpRangeMessages(fc.t, fc.resp.msgs, "A %q -> B %q (init):", storeItemStr(syncA.is), storeItemStr(syncB.is)) - // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B (init):") - for i = 0; ; i++ { - if i == maxRounds { - require.FailNow(fc.t, "too many rounds", "didn't reconcile in %d rounds", i) - } - fc.gotoResponse() - nMsg += len(fc.msgs) - nItems += fc.numItems() - var err error - bDone, err = syncB.Process(context.Background(), fc) - require.NoError(fc.t, err) - // a party should never send anything in response to the "done" message - require.False(fc.t, aDone && !bDone, "A is done but B after that is not") - // dumpRangeMessages(fc.t, fc.resp.msgs, "B %q -> A %q:", storeItemStr(syncA.is), storeItemStr(syncB.is)) - // dumpRangeMessages(fc.t, fc.resp.msgs, "B -> A:") - if aDone && bDone { - require.Empty(fc.t, fc.resp, "got messages from B in response to done msg from A") - break - } - fc.gotoResponse() - nMsg += len(fc.msgs) - nItems += fc.numItems() - aDone, err = syncA.Process(context.Background(), fc) - require.NoError(fc.t, err) - // dumpRangeMessages(fc.t, fc.msgs, "A %q --> B %q:", storeItemStr(syncB.is), storeItemStr(syncA.is)) - // dumpRangeMessages(fc.t, fc.resp.msgs, "A -> B:") - require.False(fc.t, bDone && !aDone, "B 
is done but A after that is not") - if aDone && bDone { - require.Empty(fc.t, fc.resp, "got messages from A in response to done msg from B") - break - } - } - return i + 1, nMsg, nItems -} - -func runProbe(t *testing.T, from, to *RangeSetReconciler) ProbeResult { - fc := &fakeConduit{t: t} - info, err := from.InitiateProbe(context.Background(), fc) - require.NoError(t, err) - return doRunProbe(fc, from, to, info) -} - -func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y Ordered) ProbeResult { - fc := &fakeConduit{t: t} - info, err := from.InitiateBoundedProbe(context.Background(), fc, x, y) - require.NoError(t, err) - return doRunProbe(fc, from, to, info) -} - -func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler, info RangeInfo) ProbeResult { - require.NotEmpty(fc.t, fc.resp, "empty initial round") - fc.gotoResponse() - done, err := to.Process(context.Background(), fc) - require.True(fc.t, done) - require.NoError(fc.t, err) - fc.gotoResponse() - pr, err := from.HandleProbeResponse(fc, info) - require.NoError(fc.t, err) - require.Nil(fc.t, fc.resp, "got messages from Probe in response to done msg") - return pr -} - -func TestRangeSync(t *testing.T) { - forTestStores(t, func(t *testing.T, storeFactory storeFactory) { - for _, tc := range []struct { - name string - a, b string - finalA, finalB string - x, y string - countA, countB int - fpA, fpB string - maxRounds [4]int - sim float64 - }{ - { - name: "empty sets", - a: "", - b: "", - finalA: "", - finalB: "", - countA: 0, - countB: 0, - fpA: "", - fpB: "", - maxRounds: [4]int{1, 1, 1, 1}, - sim: 1, - }, - { - name: "empty to non-empty", - a: "", - b: "abcd", - finalA: "abcd", - finalB: "abcd", - countA: 0, - countB: 4, - fpA: "", - fpB: "abcd", - maxRounds: [4]int{2, 2, 2, 2}, - sim: 0, - }, - { - name: "non-empty to empty", - a: "abcd", - b: "", - finalA: "abcd", - finalB: "abcd", - countA: 4, - countB: 0, - fpA: "abcd", - fpB: "", - maxRounds: [4]int{2, 2, 2, 2}, - sim: 0, - }, - { - 
name: "non-intersecting sets", - a: "ab", - b: "cd", - finalA: "abcd", - finalB: "abcd", - countA: 2, - countB: 2, - fpA: "ab", - fpB: "cd", - maxRounds: [4]int{3, 2, 2, 2}, - sim: 0, - }, - { - name: "intersecting sets", - a: "acdefghijklmn", - b: "bcdopqr", - finalA: "abcdefghijklmnopqr", - finalB: "abcdefghijklmnopqr", - countA: 13, - countB: 7, - fpA: "acdefghijklmn", - fpB: "bcdopqr", - maxRounds: [4]int{4, 4, 3, 3}, - sim: 0.153, - }, - { - name: "bounded reconciliation", - a: "acdefghijklmn", - b: "bcdopqr", - finalA: "abcdefghijklmn", - finalB: "abcdefgopqr", - x: "a", - y: "h", - countA: 6, - countB: 3, - fpA: "acdefg", - fpB: "bcd", - maxRounds: [4]int{3, 3, 2, 2}, - sim: 0.333, - }, - { - name: "bounded reconciliation with rollover", - a: "acdefghijklmn", - b: "bcdopqr", - finalA: "acdefghijklmnopqr", - finalB: "bcdhijklmnopqr", - x: "h", - y: "a", - countA: 7, - countB: 4, - fpA: "hijklmn", - fpB: "opqr", - maxRounds: [4]int{4, 3, 3, 2}, - sim: 0, - }, - { - name: "sync against 1-element set", - a: "bcd", - b: "a", - finalA: "abcd", - finalB: "abcd", - countA: 3, - countB: 1, - fpA: "bcd", - fpB: "a", - maxRounds: [4]int{2, 2, 2, 2}, - sim: 0, - }, - } { - t.Run(tc.name, func(t *testing.T) { - for n, maxSendRange := range []int{1, 2, 3, 4} { - t.Logf("maxSendRange: %d", maxSendRange) - storeA := makeStore(t, storeFactory, tc.a) - disableReAdd(storeA) - syncA := NewRangeSetReconciler(storeA, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - storeB := makeStore(t, storeFactory, tc.b) - disableReAdd(storeB) - syncB := NewRangeSetReconciler(storeB, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - - var ( - nRounds int - prBA, prAB ProbeResult - ) - if tc.x == "" { - prBA = runProbe(t, syncB, syncA) - prAB = runProbe(t, syncA, syncB) - nRounds, _, _ = runSync(t, syncA, syncB, tc.maxRounds[n]) - } else { - x := sampleID(tc.x) - y := sampleID(tc.y) - prBA = runBoundedProbe(t, syncB, syncA, x, y) - prAB = runBoundedProbe(t, syncA, syncB, x, 
y) - nRounds, _, _ = runBoundedSync(t, syncA, syncB, x, y, tc.maxRounds[n]) - } - t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) - - require.Equal(t, tc.countA, prBA.Count, "countA") - require.Equal(t, tc.countB, prAB.Count, "countB") - require.Equal(t, tc.fpA, prBA.FP, "fpA") - require.Equal(t, tc.fpB, prAB.FP, "fpB") - require.Equal(t, tc.finalA, storeItemStr(storeA), "finalA") - require.Equal(t, tc.finalB, storeItemStr(storeB), "finalB") - require.InDelta(t, tc.sim, prAB.Sim, 0.01, "prAB.Sim") - require.InDelta(t, tc.sim, prBA.Sim, 0.01, "prBA.Sim") - } - }) - } - }) -} - -func TestRandomSync(t *testing.T) { - forTestStores(t, func(t *testing.T, storeFactory storeFactory) { - var bytesA, bytesB []byte - defer func() { - if t.Failed() { - t.Logf("Random sync failed: %q <-> %q", bytesA, bytesB) - } - }() - for i := 0; i < 1000; i++ { - var chars []byte - for c := byte(33); c < 127; c++ { - chars = append(chars, c) - } - - bytesA = append([]byte(nil), chars...) - rand.Shuffle(len(bytesA), func(i, j int) { - bytesA[i], bytesA[j] = bytesA[j], bytesA[i] - }) - bytesA = bytesA[:rand.Intn(len(bytesA))] - storeA := makeStore(t, storeFactory, string(bytesA)) - - bytesB = append([]byte(nil), chars...) - rand.Shuffle(len(bytesB), func(i, j int) { - bytesB[i], bytesB[j] = bytesB[j], bytesB[i] - }) - bytesB = bytesB[:rand.Intn(len(bytesB))] - storeB := makeStore(t, storeFactory, string(bytesB)) - - keySet := make(map[byte]struct{}) - for _, c := range append(bytesA, bytesB...) { - keySet[byte(c)] = struct{}{} - } - - expectedSet := maps.Keys(keySet) - slices.Sort(expectedSet) - - maxSendRange := rand.Intn(16) + 1 - syncA := NewRangeSetReconciler(storeA, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - syncB := NewRangeSetReconciler(storeB, - WithMaxSendRange(maxSendRange), - WithItemChunkSize(3)) - - runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! 
- // t.Logf("maxSendRange %d a %d b %d n %d", maxSendRange, len(bytesA), len(bytesB), n) - require.Equal(t, storeItemStr(storeA), storeItemStr(storeB)) - require.Equal(t, string(expectedSet), storeItemStr(storeA), - "expected set for %q<->%q", bytesA, bytesB) - } - }) -} - -// TBD: make sure that requests with MessageTypeDone are never -// answered!!! -// TBD: use logger for verbose logging (messages) -// TBD: in fakeConduit -- check item count against the iterator in -// SendItems / SendItemsOnly!! -// TBD: record interaction using golden master in testRangeSync, for -// both probe and sync, together with N of rounds / msgs / items -// and don't check max rounds diff --git a/sync2/hashsync/sync_tree.go b/sync2/hashsync/sync_tree.go deleted file mode 100644 index 0ab32c6b8e..0000000000 --- a/sync2/hashsync/sync_tree.go +++ /dev/null @@ -1,903 +0,0 @@ -// TBD: add paper ref -package hashsync - -import ( - "fmt" - "io" - "reflect" - "slices" - "strings" - "sync" -) - -type Ordered interface { - Compare(other any) int -} - -type LowerBound struct{} - -var _ Ordered = LowerBound{} - -func (vb LowerBound) Compare(x any) int { return -1 } - -type UpperBound struct{} - -var _ Ordered = UpperBound{} - -func (vb UpperBound) Compare(x any) int { return 1 } - -type FingerprintPredicate func(fp any) bool - -func (fpred FingerprintPredicate) Match(y any) bool { - return fpred != nil && fpred(y) -} - -type SyncTree interface { - // Make a copy of the tree. The copy shares the structure with this tree but all - // its nodes are copy-on-write, so any changes in the copied tree do not affect - // this one and are safe to perform in another goroutine. The copy operation is - // O(n) where n is the number of nodes added to this tree since its creation via - // either NewSyncTree function or this Copy method, or the last call of this Copy - // method for this tree, whichever occurs last. The call to Copy is thread-safe. 
- Copy() SyncTree - Fingerprint() any - Add(k Ordered) - Set(k Ordered, v any) - Lookup(k Ordered) (any, bool) - Min() SyncTreePointer - RangeFingerprint(ptr SyncTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode SyncTreePointer) - Dump() string -} - -func SyncTreeFromSortedSlice[T Ordered](m Monoid, items []T) SyncTree { - s := make([]Ordered, len(items)) - for n, item := range items { - s[n] = item - } - st := NewSyncTree(m).(*syncTree) - st.root = st.buildFromSortedSlice(s) - return st -} - -func SyncTreeFromSlice[T Ordered](m Monoid, items []T) SyncTree { - sorted := make([]T, len(items)) - copy(sorted, items) - slices.SortFunc(sorted, func(a, b T) int { - return a.Compare(b) - }) - return SyncTreeFromSortedSlice(m, items) -} - -type SyncTreePointer interface { - Equal(other SyncTreePointer) bool - Key() Ordered - Value() any - Prev() - Next() - Clone() SyncTreePointer -} - -type flags uint8 - -const ( - // flagBlack indicates a black node. If it is not set, the - // node is red, which is the default for newly created nodes - flagBlack flags = 1 - // flagCloned indicates a node that is only present in this - // tree and not in any of its copies, thus permitting - // modification of this node without cloning it. When the tree - // is copied, flagCloned is cleared on all of its nodes. 
- flagCloned flags = 2 -) - -type dir uint8 - -const ( - left dir = 0 - right dir = 1 -) - -func (d dir) flip() dir { return d ^ 1 } - -func (d dir) String() string { - switch d { - case left: - return "left" - case right: - return "right" - default: - return fmt.Sprintf("", d) - } -} - -const initialParentStackSize = 32 - -type syncTreePointer struct { - parentStack []*syncTreeNode - node *syncTreeNode -} - -var _ SyncTreePointer = &syncTreePointer{} - -func (p *syncTreePointer) clone() *syncTreePointer { - // TODO: copy node stack - r := &syncTreePointer{ - parentStack: make([]*syncTreeNode, len(p.parentStack), cap(p.parentStack)), - node: p.node, - } - copy(r.parentStack, p.parentStack) - return r -} - -func (p *syncTreePointer) parent() { - n := len(p.parentStack) - if n == 0 { - p.node = nil - } else { - n-- - p.node = p.parentStack[n] - p.parentStack = p.parentStack[:n] - } -} - -func (p *syncTreePointer) left() { - if p.node != nil { - p.parentStack = append(p.parentStack, p.node) - p.node = p.node.left - } -} - -func (p *syncTreePointer) right() { - if p.node != nil { - p.parentStack = append(p.parentStack, p.node) - p.node = p.node.right - } -} - -func (p *syncTreePointer) min() { - for { - switch { - case p.node == nil || p.node.left == nil: - return - default: - p.left() - } - } -} - -func (p *syncTreePointer) max() { - for { - switch { - case p.node == nil || p.node.right == nil: - return - default: - p.right() - } - } -} - -func (p *syncTreePointer) Equal(other SyncTreePointer) bool { - if other == nil { - return p.node == nil - } - return p.node == other.(*syncTreePointer).node -} - -func (p *syncTreePointer) Prev() { - switch { - case p.node == nil: - case p.node.left != nil: - p.left() - p.max() - default: - oldNode := p.node - for { - p.parent() - if p.node == nil || oldNode != p.node.left { - return - } - oldNode = p.node - } - } -} - -func (p *syncTreePointer) Next() { - switch { - case p.node == nil: - case p.node.right != nil: - p.right() - 
p.min() - default: - oldNode := p.node - for { - p.parent() - if p.node == nil || oldNode != p.node.right { - return - } - oldNode = p.node - } - } -} - -func (p *syncTreePointer) Clone() SyncTreePointer { - return &syncTreePointer{ - parentStack: slices.Clone(p.parentStack), - node: p.node, - } -} - -func (p *syncTreePointer) Key() Ordered { - if p.node == nil { - return nil - } - return p.node.key -} - -func (p *syncTreePointer) Value() any { - if p.node == nil { - return nil - } - return p.node.value -} - -type syncTreeNode struct { - left *syncTreeNode - right *syncTreeNode - key Ordered - value any - max Ordered - fingerprint any - flags flags -} - -func (sn *syncTreeNode) red() bool { - return sn != nil && (sn.flags&flagBlack) == 0 -} - -func (sn *syncTreeNode) black() bool { - return sn == nil || (sn.flags&flagBlack) != 0 -} - -func (sn *syncTreeNode) child(dir dir) *syncTreeNode { - if sn == nil { - return nil - } - if dir == left { - return sn.left - } - return sn.right -} - -func (sn *syncTreeNode) Key() Ordered { return sn.key } - -func (sn *syncTreeNode) dump(w io.Writer, indent int) { - indentStr := strings.Repeat(" ", indent) - fmt.Fprintf(w, "%skey: %v\n", indentStr, sn.key) - fmt.Fprintf(w, "%smax: %v\n", indentStr, sn.max) - fmt.Fprintf(w, "%sfp: %v\n", indentStr, sn.fingerprint) - color := "red" - if sn.black() { - color = "black" - } - fmt.Fprintf(w, "%scolor: %v\n", indentStr, color) - if sn.left != nil { - fmt.Fprintf(w, "%sleft:\n", indentStr) - sn.left.dump(w, indent+1) - if sn.left.key.Compare(sn.key) >= 0 { - fmt.Fprintf(w, "%sERROR: left key >= parent key\n", indentStr) - } - } - if sn.right != nil { - fmt.Fprintf(w, "%sright:\n", indentStr) - sn.right.dump(w, indent+1) - if sn.right.key.Compare(sn.key) <= 0 { - fmt.Fprintf(w, "%sERROR: right key <= parent key\n", indentStr) - } - } -} - -func (sn *syncTreeNode) dumpSubtree() string { - var sb strings.Builder - sn.dump(&sb, 0) - return sb.String() -} - -// cleanNodes removed flagCloned 
from all of the nodes in the subtree, -// so that it can be used in further cloned trees. -// A non-cloned node cannot have any cloned children, so the function -// stops the recursion at any non-cloned node. -func (sn *syncTreeNode) cleanCloned() { - if sn == nil || sn.flags&flagCloned == 0 { - return - } - sn.flags &^= flagCloned - sn.left.cleanCloned() - sn.right.cleanCloned() -} - -type syncTree struct { - rootMtx sync.Mutex - m Monoid - root *syncTreeNode - cachedMinPtr *syncTreePointer - cachedMaxPtr *syncTreePointer -} - -func NewSyncTree(m Monoid) SyncTree { - return &syncTree{m: m} -} - -func (st *syncTree) Copy() SyncTree { - st.rootMtx.Lock() - defer st.rootMtx.Unlock() - // Clean flagCloned from any nodes created specifically for - // this tree. This will mean they will have to be re-cloned if - // they need to be changed again. - st.root.cleanCloned() - // Don't reuse cachedMinPtr / cachedMaxPtr for the cloned - // tree to be on the safe side - return &syncTree{ - m: st.m, - root: st.root, - } -} - -func (st *syncTree) rootPtr() *syncTreePointer { - return &syncTreePointer{ - parentStack: make([]*syncTreeNode, 0, initialParentStackSize), - node: st.root, - } -} - -func (st *syncTree) ensureCloned(sn *syncTreeNode) *syncTreeNode { - if sn.flags&flagCloned != 0 { - return sn - } - cloned := *sn - cloned.flags |= flagCloned - return &cloned -} - -func (st *syncTree) setChild(sn *syncTreeNode, dir dir, child *syncTreeNode) *syncTreeNode { - if sn == nil { - panic("setChild for a nil node") - } - if sn.child(dir) == child { - return sn - } - sn = st.ensureCloned(sn) - if dir == left { - sn.left = child - } else { - sn.right = child - } - return sn -} - -func (st *syncTree) flip(sn *syncTreeNode) *syncTreeNode { - if sn.left == nil || sn.right == nil { - panic("can't flip color with one or more nil children") - } - - left := st.ensureCloned(sn.left) - right := st.ensureCloned(sn.right) - sn = st.ensureCloned(sn) - sn.left = left - sn.right = right - - 
sn.flags ^= flagBlack - left.flags ^= flagBlack - right.flags ^= flagBlack - return sn -} - -func (st *syncTree) Min() SyncTreePointer { - if st.root == nil { - return nil - } - if st.cachedMinPtr == nil { - st.cachedMinPtr = st.rootPtr() - st.cachedMinPtr.min() - } - if st.cachedMinPtr.node == nil { - panic("BUG: no minNode in a non-empty tree") - } - return st.cachedMinPtr.clone() -} - -func (st *syncTree) Fingerprint() any { - if st.root == nil { - return st.m.Identity() - } - return st.root.fingerprint -} - -func (st *syncTree) newNode(k Ordered, v any) *syncTreeNode { - return &syncTreeNode{ - key: k, - value: v, - max: k, - fingerprint: st.m.Fingerprint(k), - } -} - -func (st *syncTree) buildFromSortedSlice(s []Ordered) *syncTreeNode { - switch len(s) { - case 0: - return nil - case 1: - return st.newNode(s[0], nil) - } - middle := len(s) / 2 - node := st.newNode(s[middle], nil) - node.left = st.buildFromSortedSlice(s[:middle]) - node.right = st.buildFromSortedSlice(s[middle+1:]) - if node.left != nil { - node.fingerprint = st.m.Op(node.left.fingerprint, node.fingerprint) - } - if node.right != nil { - node.fingerprint = st.m.Op(node.fingerprint, node.right.fingerprint) - node.max = node.right.max - } - return node -} - -func (st *syncTree) safeFingerprint(sn *syncTreeNode) any { - if sn == nil { - return st.m.Identity() - } - return sn.fingerprint -} - -func (st *syncTree) updateFingerprintAndMax(sn *syncTreeNode) { - fp := st.m.Op(st.safeFingerprint(sn.left), st.m.Fingerprint(sn.key)) - fp = st.m.Op(fp, st.safeFingerprint(sn.right)) - newMax := sn.key - if sn.right != nil { - newMax = sn.right.max - } - if sn.flags&flagCloned == 0 && - (!reflect.DeepEqual(sn.fingerprint, fp) || sn.max.Compare(newMax) != 0) { - panic("BUG: updating fingerprint/max for a non-cloned node") - } - sn.fingerprint = fp - sn.max = newMax -} - -func (st *syncTree) rotate(sn *syncTreeNode, d dir) *syncTreeNode { - // sn.verify() - - rd := d.flip() - tmp := sn.child(rd) - if tmp == 
nil { - panic("BUG: nil parent after rotate") - } - // fmt.Fprintf(os.Stderr, "QQQQQ: rotate %s (child at %s is %s): subtree:\n%s\n", - // d, rd, tmp.key, sn.dumpSubtree()) - sn = st.setChild(sn, rd, tmp.child(d)) - tmp = st.setChild(tmp, d, sn) - - // copy node color to the tmp - tmp.flags = (tmp.flags &^ flagBlack) | (sn.flags & flagBlack) - sn.flags &^= flagBlack // set to red - - // it's important to update sn first as it may be the new right child of - // tmp, and we need to update tmp.max too - st.updateFingerprintAndMax(sn) - st.updateFingerprintAndMax(tmp) - - return tmp -} - -func (st *syncTree) doubleRotate(sn *syncTreeNode, d dir) *syncTreeNode { - rd := d.flip() - sn = st.setChild(sn, rd, st.rotate(sn.child(rd), rd)) - return st.rotate(sn, d) -} - -func (st *syncTree) Add(k Ordered) { - st.add(k, nil, false) -} - -func (st *syncTree) Set(k Ordered, v any) { - st.add(k, v, true) -} - -func (st *syncTree) add(k Ordered, v any, set bool) { - st.rootMtx.Lock() - defer st.rootMtx.Unlock() - st.root = st.insert(st.root, k, v, true, set) - if st.root.flags&flagBlack == 0 { - st.root = st.ensureCloned(st.root) - st.root.flags |= flagBlack - } -} - -func (st *syncTree) insert(sn *syncTreeNode, k Ordered, v any, rb, set bool) *syncTreeNode { - // simplified insert implementation idea from - // https://zarif98sjs.github.io/blog/blog/redblacktree/ - if sn == nil { - sn = st.newNode(k, v) - // the new node is not really "cloned", but at this point it's - // only present in this tree so we can safely modify it - // without allocating new nodes - sn.flags |= flagCloned - // when the tree is being modified, cached min/max ptrs are no longer valid - st.cachedMinPtr = nil - st.cachedMaxPtr = nil - return sn - } - c := k.Compare(sn.key) - if c == 0 { - if v != sn.value { - sn = st.ensureCloned(sn) - sn.value = v - } - return sn - } - d := left - if c > 0 { - d = right - } - oldChild := sn.child(d) - newChild := st.insert(oldChild, k, v, rb, set) - sn = st.setChild(sn, d, 
newChild) - updateFP := true - if rb { - // non-red-black insert is used for testing - sn, updateFP = st.insertFixup(sn, d, oldChild != newChild) - } - if updateFP { - st.updateFingerprintAndMax(sn) - } - return sn -} - -// insertFixup fixes a subtree after insert according to Red-Black tree rules. -// It returns the updated node and a boolean indicating whether the fingerprint/max -// update is needed. The latter is NOT the case -func (st *syncTree) insertFixup(sn *syncTreeNode, d dir, updateFP bool) (*syncTreeNode, bool) { - child := sn.child(d) - rd := d.flip() - switch { - case child.black(): - return sn, true - case sn.child(rd).red(): - // both children of sn are red => any child has 2 reds in a row - // (LL LR RR RL) => flip colors - if child.child(d).red() || child.child(rd).red() { - return st.flip(sn), true - } - return sn, true - case child.child(d).red(): - // another child of sn is black - // any child has 2 reds in a row (LL RR) => rotate - // rotate will update fingerprint of sn and the node - // that replaces it - return st.rotate(sn, rd), updateFP - case child.child(rd).red(): - // another child of sn is black - // any child has 2 reds in a row (LR RL) => align first, then rotate - // doubleRotate will update fingerprint of sn and the node - // that replaces it - return st.doubleRotate(sn, rd), updateFP - default: - return sn, true - } -} - -func (st *syncTree) Lookup(k Ordered) (any, bool) { - // TODO: lookups shouldn't cause any allocation! 
- ptr := st.rootPtr() - if !st.findGTENode(ptr, k) || ptr.node == nil || ptr.Key().Compare(k) != 0 { - return nil, false - } - return ptr.Value(), true -} - -func (st *syncTree) findGTENode(ptr *syncTreePointer, x Ordered) bool { - for { - switch { - case ptr.node == nil: - return false - case x.Compare(ptr.node.key) == 0: - // Exact match - return true - case x.Compare(ptr.node.max) > 0: - // All of this subtree is below v, maybe we can have - // some luck with the parent node - ptr.parent() - st.findGTENode(ptr, x) - case x.Compare(ptr.node.key) >= 0: - // We're still below x (or at x, but allowEqual is - // false), but given that we checked Max and saw that - // this subtree has some keys that are greater than - // or equal to x, we can find them on the right - if ptr.node.right == nil { - // sn.Max lied to us - // TODO: QQQQQ: this bug is being hit - panic("BUG: SyncTreeNode: x > sn.Max but no right branch") - } - // Avoid endless recursion in case of a bug - if x.Compare(ptr.node.right.max) > 0 { - // TODO: QQQQQ: this bug is being hit - panic("BUG: SyncTreeNode: inconsistent Max on the right branch") - } - ptr.right() - case ptr.node.left == nil || x.Compare(ptr.node.left.max) > 0: - // The current node's key is greater than x and the - // left branch is either empty or fully below x, so - // the current node is what we were looking for - return true - default: - // Some keys on the left branch are greater or equal - // than x accordingto sn.Left.Max - ptr.left() - } - } -} - -func (st *syncTree) rangeFingerprint(preceding SyncTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startPtr, endPtr *syncTreePointer) { - if st.root == nil { - return st.m.Identity(), nil, nil - } - var ptr *syncTreePointer - if preceding == nil { - ptr = st.rootPtr() - } else { - ptr = preceding.(*syncTreePointer) - } - - minPtr := st.Min().(*syncTreePointer) - acc := st.m.Identity() - haveGTE := st.findGTENode(ptr, start) - startPtr = ptr.clone() - switch { - case 
start.Compare(end) >= 0: - // rollover range, which includes the case start == end - // this includes 2 subranges: - // [start, max_element] and [min_element, end) - var stopped bool - if haveGTE { - acc, stopped = st.aggregateUntil(ptr, acc, start, UpperBound{}, stop) - } - - if !stopped && end.Compare(minPtr.Key()) > 0 { - ptr = minPtr.clone() - acc, _ = st.aggregateUntil(ptr, acc, LowerBound{}, end, stop) - } - case haveGTE: - // normal range, that is, start < end - acc, _ = st.aggregateUntil(ptr, st.m.Identity(), start, end, stop) - } - - if startPtr.node == nil { - startPtr = minPtr.clone() - } - if ptr.node == nil { - ptr = minPtr.clone() - } - - return acc, startPtr, ptr -} - -func (st *syncTree) RangeFingerprint(ptr SyncTreePointer, start, end Ordered, stop FingerprintPredicate) (fp any, startNode, endNode SyncTreePointer) { - fp, startPtr, endPtr := st.rangeFingerprint(ptr, start, end, stop) - switch { - case startPtr == nil && endPtr == nil: - // avoid wrapping nil in SyncTreePointer interface - return fp, nil, nil - case startPtr == nil || endPtr == nil: - panic("BUG: can't have nil node just on one end") - default: - return fp, startPtr, endPtr - } -} - -func (st *syncTree) aggregateUntil(ptr *syncTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { - acc, stopped = st.aggregateUp(ptr, acc, start, end, stop) - if ptr.node == nil || end.Compare(ptr.node.key) <= 0 || stopped { - return acc, stopped - } - - // fmt.Fprintf(os.Stderr, "QQQQQ: from aggregateUp: acc %q; ptr.node %q\n", acc, ptr.node.key) - f := st.m.Op(acc, st.m.Fingerprint(ptr.node.key)) - if stop.Match(f) { - return acc, true - } - ptr.right() - return st.aggregateDown(ptr, f, end, stop) -} - -// aggregateUp ascends from the left (lower) end of the range towards the LCA -// (lowest common ancestor) of nodes within the range [start,end). 
Instead of -// descending from the root node, the LCA is determined by the way of checking -// whether the stored max subtree key is below or at the end or not, saving -// some extra tree traversal when processing the ascending ranges. -// On the way up, if the current node is within the range, we include the right -// subtree in the aggregation using its saved fingerprint, as it is guaranteed -// to lie with the range. When we happen to go up from the right branch, we can -// only reach a predecessor node that lies below the start, and in this case we -// don't include the right subtree in the aggregation to avoid aggregating the -// same subset of nodes twice. -// If stop function is passed, we find the node on which it returns true -// for the fingerprint accumulated between start and that node, if the target -// node is somewhere to the left from the LCA. -func (st *syncTree) aggregateUp(ptr *syncTreePointer, acc any, start, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { - for { - switch { - case ptr.node == nil: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: null node\n") - return acc, false - case stop.Match(acc): - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop: node %v acc %v\n", sn.key, acc) - ptr.Prev() - return acc, true - case end.Compare(ptr.node.max) <= 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: LCA: node %v acc %v\n", sn.key, acc) - // This node is a the LCA, the starting point for AggregateDown - return acc, false - case start.Compare(ptr.node.key) <= 0: - // This node is within the target range - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: in-range node %v acc %v\n", sn.key, acc) - f := st.m.Op(acc, st.m.Fingerprint(ptr.node.key)) - if stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: stop at the own node %v acc %v\n", sn.key, acc) - return acc, true - } - f1 := st.m.Op(f, st.safeFingerprint(ptr.node.right)) - if stop.Match(f1) { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree 
matches node %v acc %v f1 %v\n", sn.key, acc, f1) - // The target node is somewhere in the right subtree - if ptr.node.right == nil { - panic("BUG: nil right child with non-identity fingerprint") - } - ptr.right() - acc := st.boundedAggregate(ptr, f, stop) - if ptr.node == nil { - panic("BUG: aggregateUp: bad subtree fingerprint on the right branch") - } - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: right subtree: node %v acc %v\n", node.key, acc) - return acc, true - } else { - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateUp: no right subtree match: node %v acc %v f1 %v\n", sn.key, acc, f1) - acc = f1 - } - } - ptr.parent() - } -} - -// aggregateDown descends from the LCA (lowest common ancestor) of nodes within -// the range ending at the 'end'. On the way down, the unvisited left subtrees -// are guaranteed to lie within the range, so they're included into the -// aggregation using their saved fingerprint. -// If stop function is passed, we find the node on which it returns true -// for the fingerprint accumulated between start and that node -func (st *syncTree) aggregateDown(ptr *syncTreePointer, acc any, end Ordered, stop FingerprintPredicate) (fp any, stopped bool) { - for { - switch { - case ptr.node == nil: - // fmt.Fprintf(os.Stderr, "QQQQQ: sn == nil\n") - return acc, false - case stop.Match(acc): - // fmt.Fprintf(os.Stderr, "QQQQQ: stop on node\n") - ptr.Prev() - return acc, true - case end.Compare(ptr.node.key) > 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: within the range\n") - // We're within the range but there also may be nodes - // within the range to the right. 
The left branch is - // fully within the range - f := st.m.Op(acc, st.safeFingerprint(ptr.node.left)) - if stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: left subtree covers it\n") - // The target node is somewhere in the left subtree - if ptr.node.left == nil { - panic("BUG: aggregateDown: nil left child with non-identity fingerprint") - } - ptr.left() - return st.boundedAggregate(ptr, acc, stop), true - } - f1 := st.m.Op(f, st.m.Fingerprint(ptr.node.key)) - if stop.Match(f1) { - // fmt.Fprintf(os.Stderr, "QQQQQ: stop at the node, prev %#v\n", node.prev()) - return f, true - } else { - acc = f1 - } - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown on the right\n") - ptr.right() - case ptr.node.left == nil || end.Compare(ptr.node.left.max) > 0: - // fmt.Fprintf(os.Stderr, "QQQQQ: node covers the range\n") - // Found the rightmost bounding node - f := st.m.Op(acc, st.safeFingerprint(ptr.node.left)) - if stop.Match(f) { - // The target node is somewhere in the left subtree - if ptr.node.left == nil { - panic("BUG: aggregateDown: nil left child with non-identity fingerprint") - } - // XXXXX fixme - ptr.left() - return st.boundedAggregate(ptr, acc, stop), true - } - return f, false - default: - // fmt.Fprintf(os.Stderr, "QQQQQ: aggregateDown: going further down\n") - // We're too far to the right, outside the range - ptr.left() - } - } -} - -func (st *syncTree) boundedAggregate(ptr *syncTreePointer, acc any, stop FingerprintPredicate) any { - for { - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: node %v, acc %v\n", sn.key, acc) - if ptr.node == nil { - return acc - } - - // If we don't need to stop, or if the stop point is somewhere after - // this subtree, we can just use the pre-calculated subtree fingerprint - if f := st.m.Op(acc, ptr.node.fingerprint); !stop.Match(f) { - return f - } - - // This function is not supposed to be called with acc already matching - // the stop condition - if stop.Match(acc) { - panic("BUG: boundedAggregate: initial 
fingerprint is matched before the first node") - } - - if ptr.node.left != nil { - // See if we can skip recursion on the left branch - f := st.m.Op(acc, ptr.node.left.fingerprint) - if !stop.Match(f) { - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Left non-nil and no-stop %v, f %v, left fingerprint %v\n", sn.key, f, sn.Left.Fingerprint) - acc = f - } else { - // The target node must be contained in the left subtree - ptr.left() - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Left non-nil and stop %v, new node %v, acc %v\n", sn.key, node.key, acc) - continue - } - } - - f := st.m.Op(acc, st.m.Fingerprint(ptr.node.key)) - if stop.Match(f) { - return acc - } - acc = f - - if ptr.node.right != nil { - f1 := st.m.Op(f, ptr.node.right.fingerprint) - if !stop.Match(f1) { - // The right branch is still below the target fingerprint - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Right non-nil and no-stop %v, acc %v\n", sn.key, acc) - acc = f1 - } else { - // The target node must be contained in the right subtree - acc = f - ptr.right() - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: sn Right non-nil and stop %v, new node %v, acc %v\n", sn.key, node.key, acc) - continue - } - } - // fmt.Fprintf(os.Stderr, "QQQQQ: boundedAggregate: %v -- return acc %v\n", sn.key, acc) - return acc - } -} - -func (st *syncTree) Dump() string { - if st.root == nil { - return "" - } - var sb strings.Builder - st.root.dump(&sb, 0) - return sb.String() -} - -// TODO: use sync.Pool for node alloc -// see also: -// https://www.akshaydeo.com/blog/2017/12/23/How-did-I-improve-latency-by-700-percent-using-syncPool/ -// so may need refcounting diff --git a/sync2/hashsync/sync_tree_store.go b/sync2/hashsync/sync_tree_store.go deleted file mode 100644 index 2c5cab0e01..0000000000 --- a/sync2/hashsync/sync_tree_store.go +++ /dev/null @@ -1,175 +0,0 @@ -package hashsync - -import ( - "context" - "errors" - "time" -) - -type syncTreeIterator struct { - st SyncTree - ptr 
SyncTreePointer -} - -var _ Iterator = &syncTreeIterator{} - -func (it *syncTreeIterator) Equal(other Iterator) bool { - o := other.(*syncTreeIterator) - if it.st != o.st { - panic("comparing iterators from different SyncTreeStore") - } - return it.ptr.Equal(o.ptr) -} - -func (it *syncTreeIterator) Key() (Ordered, error) { - return it.ptr.Key(), nil -} - -func (it *syncTreeIterator) Next() error { - it.ptr.Next() - if it.ptr.Key() == nil { - it.ptr = it.st.Min() - } - return nil -} - -func (it *syncTreeIterator) Clone() Iterator { - return &syncTreeIterator{ - st: it.st, - ptr: it.ptr.Clone(), - } -} - -type SyncTreeStore struct { - st SyncTree - identity any -} - -var _ ItemStore = &SyncTreeStore{} - -func NewSyncTreeStore(m Monoid) ItemStore { - return &SyncTreeStore{ - st: NewSyncTree(CombineMonoids(m, CountingMonoid{})), - identity: m.Identity(), - } -} - -// Add implements ItemStore. -func (sts *SyncTreeStore) Add(ctx context.Context, k Ordered) error { - sts.st.Set(k, nil) - return nil -} - -func (sts *SyncTreeStore) iter(ptr SyncTreePointer) Iterator { - if ptr == nil { - return nil - } - return &syncTreeIterator{ - st: sts.st, - ptr: ptr, - } -} - -// GetRangeInfo implements ItemStore. 
-func (sts *SyncTreeStore) GetRangeInfo( - ctx context.Context, - preceding Iterator, - x, y Ordered, - count int, -) (RangeInfo, error) { - if x == nil && y == nil { - it, err := sts.Min(ctx) - if err != nil { - return RangeInfo{}, err - } - if it == nil { - return RangeInfo{ - Fingerprint: sts.identity, - }, nil - } else { - x, err = it.Key() - if err != nil { - return RangeInfo{}, err - } - y = x - } - } else if x == nil || y == nil { - panic("BUG: bad X or Y") - } - var stop FingerprintPredicate - var node SyncTreePointer - if preceding != nil { - p := preceding.(*syncTreeIterator) - if p.st != sts.st { - panic("GetRangeInfo: preceding iterator from a wrong SyncTreeStore") - } - node = p.ptr - } - if count >= 0 { - stop = func(fp any) bool { - return CombinedSecond[int](fp) > count - } - } - fp, startPtr, endPtr := sts.st.RangeFingerprint(node, x, y, stop) - cfp := fp.(CombinedFingerprint) - return RangeInfo{ - Fingerprint: cfp.First, - Count: cfp.Second.(int), - Start: sts.iter(startPtr), - End: sts.iter(endPtr), - }, nil -} - -// SplitRange implements ItemStore. -func (sts *SyncTreeStore) SplitRange( - ctx context.Context, - preceding Iterator, - x, y Ordered, - count int, -) (SplitInfo, error) { - if count <= 0 { - panic("BUG: bad split count") - } - part0, err := sts.GetRangeInfo(ctx, preceding, x, y, count) - if err != nil { - return SplitInfo{}, err - } - if part0.Count == 0 { - return SplitInfo{}, errors.New("can't split empty range") - } - middle, err := part0.End.Key() - if err != nil { - return SplitInfo{}, err - } - part1, err := sts.GetRangeInfo(ctx, part0.End.Clone(), middle, y, -1) - if err != nil { - return SplitInfo{}, err - } - return SplitInfo{ - Parts: [2]RangeInfo{part0, part1}, - Middle: middle, - }, nil -} - -// Min implements ItemStore. -func (sts *SyncTreeStore) Min(ctx context.Context) (Iterator, error) { - return sts.iter(sts.st.Min()), nil -} - -// Copy implements ItemStore. 
-func (sts *SyncTreeStore) Copy() ItemStore { - return &SyncTreeStore{ - st: sts.st.Copy(), - identity: sts.identity, - } -} - -// Has implements ItemStore. -func (sts *SyncTreeStore) Has(ctx context.Context, k Ordered) (bool, error) { - _, found := sts.st.Lookup(k) - return found, nil -} - -func (sts *SyncTreeStore) Recent(ctx context.Context, since time.Time) (Iterator, int, error) { - return nil, 0, nil -} diff --git a/sync2/hashsync/sync_tree_test.go b/sync2/hashsync/sync_tree_test.go deleted file mode 100644 index 1eb60a06dc..0000000000 --- a/sync2/hashsync/sync_tree_test.go +++ /dev/null @@ -1,568 +0,0 @@ -package hashsync - -import ( - "cmp" - "fmt" - "math/rand" - "slices" - "sync" - "testing" - - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/stretchr/testify/require" -) - -type sampleID string - -var _ Ordered = sampleID("") - -func (s sampleID) String() string { return string(s) } -func (s sampleID) Compare(other any) int { - return cmp.Compare(s, other.(sampleID)) -} - -type sampleMonoid struct{} - -var _ Monoid = sampleMonoid{} - -func (m sampleMonoid) Identity() any { return "" } -func (m sampleMonoid) Op(a, b any) any { return a.(string) + b.(string) } -func (m sampleMonoid) Fingerprint(a any) any { return string(a.(sampleID)) } - -func sampleCountMonoid() Monoid { - return CombineMonoids(sampleMonoid{}, CountingMonoid{}) -} - -func makeStringConcatTree(chars string) SyncTree { - ids := make([]sampleID, len(chars)) - for n, c := range chars { - ids[n] = sampleID(c) - } - return SyncTreeFromSlice(sampleCountMonoid(), ids) -} - -// dumbAdd inserts the node into the tree without trying to maintain the -// red-black properties -func dumbAdd(st SyncTree, k Ordered) { - stree := st.(*syncTree) - stree.root = stree.insert(stree.root, k, nil, false, false) -} - -// makeDumbTree constructs a binary tree by adding the chars one-by-one without -// trying to maintain the red-black properties -func makeDumbTree(chars string) SyncTree { - if 
len(chars) == 0 { - panic("empty set") - } - st := NewSyncTree(sampleCountMonoid()) - for _, c := range chars { - dumbAdd(st, sampleID(c)) - } - return st -} - -func makeRBTree(chars string) SyncTree { - st := NewSyncTree(sampleCountMonoid()) - for _, c := range chars { - st.Add(sampleID(c)) - } - return st -} - -func gtePos(all, item string) int { - n := slices.IndexFunc([]byte(all), func(v byte) bool { - return v >= item[0] - }) - if n >= 0 { - return n - } - return len(all) -} - -func naiveRange(all, x, y string, stopCount int) (fingerprint, startStr, endStr string) { - if len(all) == 0 { - return "", "", "" - } - allBytes := []byte(all) - slices.Sort(allBytes) - all = string(allBytes) - start := gtePos(all, x) - end := gtePos(all, y) - if x < y { - if stopCount >= 0 && end-start > stopCount { - end = start + stopCount - } - if end < len(all) { - endStr = all[end : end+1] - } else { - endStr = all[0:1] - } - startStr = "" - if start < len(all) { - startStr = all[start : start+1] - } else { - startStr = all[0:1] - } - return all[start:end], startStr, endStr - } else { - r := all[start:] + all[:end] - // fmt.Fprintf(os.Stderr, "QQQQQ: x %q start %d y %q end %d\n", x, start, y, end) - if len(r) == 0 { - // fmt.Fprintf(os.Stderr, "QQQQQ: x %q start %d y %q end %d -- ret start\n", x, start, y, end) - return "", all[0:1], all[0:1] - } - if stopCount >= 0 && len(r) > stopCount { - return r[:stopCount], r[0:1], r[stopCount : stopCount+1] - } - if end < len(all) { - endStr = all[end : end+1] - } else { - endStr = all[0:1] - } - startStr = "" - if len(r) != 0 { - startStr = r[0:1] - } - return r, startStr, endStr - } -} - -func TestEmptyTree(t *testing.T) { - tree := NewSyncTree(sampleCountMonoid()) - rfp1, startNode, endNode := tree.RangeFingerprint(nil, sampleID("a"), sampleID("a"), nil) - require.Nil(t, startNode) - require.Nil(t, endNode) - rfp2, startNode, endNode := tree.RangeFingerprint(nil, sampleID("a"), sampleID("c"), nil) - require.Nil(t, startNode) - 
require.Nil(t, endNode) - rfp3, startNode, endNode := tree.RangeFingerprint(nil, sampleID("c"), sampleID("a"), nil) - require.Nil(t, startNode) - require.Nil(t, endNode) - for _, fp := range []any{ - tree.Fingerprint(), - rfp1, - rfp2, - rfp3, - } { - require.Equal(t, "", CombinedFirst[string](fp)) - require.Equal(t, 0, CombinedSecond[int](fp)) - } -} - -func testSyncTreeRanges(t *testing.T, tree SyncTree) { - all := "abcdefghijklmnopqr" - for _, tc := range []struct { - all string - x, y sampleID - gte string - fp string - stop int - startAt sampleID - endAt sampleID - }{ - // normal ranges: [x, y) (x -> y) - {x: "0", y: "9", stop: -1, startAt: "a", endAt: "a", fp: ""}, - {x: "x", y: "y", stop: -1, startAt: "a", endAt: "a", fp: ""}, - {x: "a", y: "b", stop: -1, startAt: "a", endAt: "b", fp: "a"}, - {x: "a", y: "d", stop: -1, startAt: "a", endAt: "d", fp: "abc"}, - {x: "f", y: "o", stop: -1, startAt: "f", endAt: "o", fp: "fghijklmn"}, - {x: "0", y: "y", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, - {x: "a", y: "r", stop: -1, startAt: "a", endAt: "r", fp: "abcdefghijklmnopq"}, - // full rollover range x -> end -> x, or [x, max) + [min, x) - {x: "a", y: "a", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, - {x: "l", y: "l", stop: -1, startAt: "l", endAt: "l", fp: "lmnopqrabcdefghijk"}, - // rollover ranges: x -> end -> y, or [x, max), [min, y) - {x: "l", y: "f", stop: -1, startAt: "l", endAt: "f", fp: "lmnopqrabcde"}, - {x: "l", y: "0", stop: -1, startAt: "l", endAt: "a", fp: "lmnopqr"}, - {x: "y", y: "f", stop: -1, startAt: "a", endAt: "f", fp: "abcde"}, - {x: "y", y: "x", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, - {x: "9", y: "0", stop: -1, startAt: "a", endAt: "a", fp: "abcdefghijklmnopqr"}, - {x: "s", y: "a", stop: -1, startAt: "a", endAt: "a", fp: ""}, - // normal ranges + stop - {x: "a", y: "q", stop: 0, startAt: "a", endAt: "a", fp: ""}, - {x: "a", y: "q", stop: 3, startAt: "a", endAt: "d", fp: "abc"}, - 
{x: "a", y: "q", stop: 5, startAt: "a", endAt: "f", fp: "abcde"}, - {x: "a", y: "q", stop: 7, startAt: "a", endAt: "h", fp: "abcdefg"}, - {x: "a", y: "q", stop: 16, startAt: "a", endAt: "q", fp: "abcdefghijklmnop"}, - // rollover ranges + stop - {x: "l", y: "f", stop: 3, startAt: "l", endAt: "o", fp: "lmn"}, - {x: "l", y: "f", stop: 8, startAt: "l", endAt: "b", fp: "lmnopqra"}, - {x: "y", y: "x", stop: 5, startAt: "a", endAt: "f", fp: "abcde"}, - // full rollover range + stop - {x: "a", y: "a", stop: 3, startAt: "a", endAt: "d", fp: "abc"}, - {x: "a", y: "a", stop: 10, startAt: "a", endAt: "k", fp: "abcdefghij"}, - {x: "l", y: "l", stop: 3, startAt: "l", endAt: "o", fp: "lmn"}, - } { - testName := fmt.Sprintf("%s-%s", tc.x, tc.y) - if tc.stop >= 0 { - testName += fmt.Sprintf("-%d", tc.stop) - } - t.Run(testName, func(t *testing.T) { - rootFP := tree.Fingerprint() - require.Equal(t, all, CombinedFirst[string](rootFP)) - require.Equal(t, len(all), CombinedSecond[int](rootFP)) - stopCounts := []int{tc.stop} - if tc.stop < 0 { - // Stop point at the end of the sequence or beyond it - // should produce the same results as no stop point at all - stopCounts = append(stopCounts, len(all), len(all)*2) - } - for _, stopCount := range stopCounts { - // make sure naiveRangeWithStopCount works as epxected, even - // though it is only used for tests - fpStr, startStr, endStr := naiveRange(all, string(tc.x), string(tc.y), stopCount) - require.Equal(t, tc.fp, fpStr, "naive fingerprint") - require.Equal(t, string(tc.startAt), startStr, "naive fingerprint: startAt") - require.Equal(t, string(tc.endAt), endStr, "naive fingerprint: endAt") - - var stop FingerprintPredicate - if stopCount >= 0 { - // stopCount is not used after this iteration - // so it's ok to have it captured in the closure - stop = func(fp any) bool { - count := CombinedSecond[int](fp) - return count > stopCount - } - } - fp, startNode, endNode := tree.RangeFingerprint(nil, tc.x, tc.y, stop) - require.Equal(t, 
tc.fp, CombinedFirst[string](fp), "fingerprint") - require.Equal(t, len(tc.fp), CombinedSecond[int](fp), "count") - require.NotNil(t, startNode, "start node") - require.NotNil(t, endNode, "end node") - require.Equal(t, tc.startAt, startNode.Key(), "start node key") - require.Equal(t, tc.endAt, endNode.Key(), "end node key") - } - }) - } -} - -func TestSyncTreeRanges(t *testing.T) { - t.Run("pre-balanced tree", func(t *testing.T) { - testSyncTreeRanges(t, makeStringConcatTree("abcdefghijklmnopqr")) - }) - t.Run("sequential add", func(t *testing.T) { - testSyncTreeRanges(t, makeDumbTree("abcdefghijklmnopqr")) - }) - t.Run("shuffled add", func(t *testing.T) { - testSyncTreeRanges(t, makeDumbTree("lodrnifeqacmbhkgjp")) - }) - t.Run("red-black add", func(t *testing.T) { - testSyncTreeRanges(t, makeRBTree("lodrnifeqacmbhkgjp")) - }) -} - -func TestAscendingRanges(t *testing.T) { - all := "abcdefghijklmnopqr" - tree := makeRBTree(all) - for _, tc := range []struct { - name string - ranges []string - fingerprints []string - }{ - { - name: "normal ranges", - ranges: []string{"ac", "cj", "lq", "qr"}, - fingerprints: []string{"ab", "cdefghi", "lmnop", "q"}, - }, - { - name: "normal and inverted ranges", - ranges: []string{"xc", "cj", "p0"}, - fingerprints: []string{"ab", "cdefghi", "pqr"}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - var fps []string - var node SyncTreePointer - for n, rng := range tc.ranges { - x := sampleID(rng[0]) - y := sampleID(rng[1]) - if n > 0 { - require.NotNil(t, node, "nil starting node for range %s-%s", x, y) - } - fpStr, _, _ := naiveRange(all, string(x), string(y), -1) - var fp any - fp, _, node = tree.RangeFingerprint(node, x, y, nil) - actualFP := CombinedFirst[string](fp) - require.Equal(t, len(actualFP), CombinedSecond[int](fp), "count") - require.Equal(t, fpStr, actualFP) - fps = append(fps, actualFP) - } - require.Equal(t, tc.fingerprints, fps, "fingerprints") - }) - } -} - -func verifyBinaryTree(t *testing.T, sn *syncTreeNode) { - 
cloned := sn.flags&flagCloned != 0 - if sn.left != nil { - if !cloned { - require.Zero(t, sn.left.flags&flagCloned, "cloned left child of a non-cloned node") - } - require.Negative(t, sn.left.key.Compare(sn.key)) - // not a "real" pointer (no parent stack), just to get max - leftMax := &syncTreePointer{node: sn.left} - leftMax.max() - require.Negative(t, leftMax.Key().Compare(sn.key)) - verifyBinaryTree(t, sn.left) - } - - if sn.right != nil { - if !cloned { - require.Zero(t, sn.right.flags&flagCloned, "cloned right child of a non-cloned node") - } - require.Positive(t, sn.right.key.Compare(sn.key)) - // not a "real" pointer (no parent stack), just to get min - rightMin := &syncTreePointer{node: sn.right} - rightMin.min() - require.Positive(t, rightMin.Key().Compare(sn.key)) - verifyBinaryTree(t, sn.right) - } -} - -func verifyRedBlackNode(t *testing.T, sn *syncTreeNode, blackDepth int) int { - if sn == nil { - return blackDepth + 1 - } - if sn.flags&flagBlack == 0 { - if sn.left != nil { - require.Equal(t, flagBlack, sn.left.flags&flagBlack, "left child of a red node is red") - } - if sn.right != nil { - require.Equal(t, flagBlack, sn.right.flags&flagBlack, "right child of a red node is red") - } - } else { - blackDepth++ - } - bdLeft := verifyRedBlackNode(t, sn.left, blackDepth) - bdRight := verifyRedBlackNode(t, sn.right, blackDepth) - require.Equal(t, bdLeft, bdRight, "subtree black depth for node %s", sn.key) - return bdLeft -} - -func verifyRedBlack(t *testing.T, st *syncTree) { - if st.root == nil { - return - } - require.Equal(t, flagBlack, st.root.flags&flagBlack, "root node must be black") - verifyRedBlackNode(t, st.root, 0) -} - -func TestRedBlackTreeInsert(t *testing.T) { - for i := 0; i < 1000; i++ { - tree := NewSyncTree(sampleCountMonoid()) - items := []byte("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz") - count := rand.Intn(len(items)) + 1 - items = items[:count] - shuffled := append([]byte(nil), items...) 
- rand.Shuffle(len(shuffled), func(i, j int) { - shuffled[i], shuffled[j] = shuffled[j], shuffled[i] - }) - - // items := []byte("0123456789ABCDEFG") - // shuffled := []byte("0678DF1CG5A9324BE") - - trees := make([]SyncTree, len(shuffled)) - treeDumps := make([]string, len(shuffled)) - for i := 0; i < len(shuffled); i++ { - trees[i] = tree.Copy() - treeDumps[i] = tree.Dump() - require.Equal(t, treeDumps[i], trees[i].Dump(), "initial tree dump %d", i) - tree.Add(sampleID(shuffled[i])) - if i >= 3 && i%3 == 0 { - // this shouldn't change anything - trees[i-1].Add(sampleID(shuffled[rand.Intn(i-1)])) - // cloning should not happen b/c no new nodes are inserted - require.Zero(t, trees[i-1].(*syncTree).root.flags&flagCloned) - } - } - - for i := 0; i < len(shuffled); i++ { - require.Equal(t, treeDumps[i], trees[i].Dump(), "tree dump %d after copy", i) - } - - var actualItems []byte - n := 0 - // t.Logf("items: %q", string(items)) - // t.Logf("shuffled: %q", string(shuffled)) - // t.Logf("QQQQQ: tree:\n%s", tree.Dump()) - verifyBinaryTree(t, tree.(*syncTree).root) - verifyRedBlack(t, tree.(*syncTree)) - for ptr := tree.Min(); ptr.Key() != nil; ptr.Next() { - // avoid endless loop due to bugs in the tree impl - require.Less(t, n, len(items)*2, "got much more items than needed: %q -- %q", actualItems, shuffled) - n++ - actualItems = append(actualItems, ptr.Key().(sampleID)[0]) - } - require.Equal(t, items, actualItems) - - fp, startNode, endNode := tree.RangeFingerprint(nil, sampleID(items[0]), sampleID(items[0]), nil) - fpStr := CombinedFirst[string](fp) - require.Equal(t, string(items), fpStr, "fingerprint %q", shuffled) - require.Equal(t, len(fpStr), CombinedSecond[int](fp), "count %q") - require.Equal(t, sampleID(items[0]), startNode.Key(), "startNode") - require.Equal(t, sampleID(items[0]), endNode.Key(), "endNode") - } -} - -type makeTestTreeFunc func(chars string) SyncTree - -func testRandomOrderAndRanges(t *testing.T, mktree makeTestTreeFunc) { - all := 
"abcdefghijklmnopqr" - for i := 0; i < 1000; i++ { - shuffled := []byte(all) - rand.Shuffle(len(shuffled), func(i, j int) { - shuffled[i], shuffled[j] = shuffled[j], shuffled[i] - }) - tree := mktree(string(shuffled)) - x := sampleID(shuffled[rand.Intn(len(shuffled))]) - y := sampleID(shuffled[rand.Intn(len(shuffled))]) - stopCount := rand.Intn(len(shuffled)+2) - 1 - var stop FingerprintPredicate - if stopCount >= 0 { - stop = func(fp any) bool { - return CombinedSecond[int](fp) > stopCount - } - } - - verify := func() { - expFP, expStart, expEnd := naiveRange(all, string(x), string(y), stopCount) - fp, startNode, endNode := tree.RangeFingerprint(nil, x, y, stop) - - fpStr := CombinedFirst[string](fp) - curCase := fmt.Sprintf("items %q x %q y %q stopCount %d", shuffled, x, y, stopCount) - require.Equal(t, expFP, fpStr, "%s: fingerprint", curCase) - require.Equal(t, len(fpStr), CombinedSecond[int](fp), "%s: count", curCase) - - startStr := "" - if startNode != nil { - startStr = string(startNode.Key().(sampleID)) - } - require.Equal(t, expStart, startStr, "%s: next", curCase) - - endStr := "" - if endNode != nil { - endStr = string(endNode.Key().(sampleID)) - } - require.Equal(t, expEnd, endStr, "%s: next", curCase) - } - verify() - tree1 := tree.Copy() - tree1.Add(sampleID("s")) - tree1.Add(sampleID("t")) - tree1.Add(sampleID("u")) - verify() // the original tree should be unchanged - fp, _, _ := tree1.RangeFingerprint(nil, sampleID("a"), sampleID("a"), nil) - require.Equal(t, "abcdefghijklmnopqrstu", CombinedFirst[string](fp)) - require.Equal(t, len(all)+3, CombinedSecond[int](fp)) - } -} - -func TestRandomOrderAndRanges(t *testing.T) { - t.Run("randomized dumb insert", func(t *testing.T) { - testRandomOrderAndRanges(t, makeDumbTree) - }) - t.Run("red-black tree", func(t *testing.T) { - testRandomOrderAndRanges(t, makeRBTree) - }) -} - -func TestTreeValues(t *testing.T) { - tree := makeRBTree("") - tree.Add(sampleID("a")) - tree.Set(sampleID("b"), 123) - 
tree.Set(sampleID("d"), 456) - verifyOrig := func() { - v, found := tree.Lookup(sampleID("a")) - require.True(t, found) - require.Nil(t, v) - v, found = tree.Lookup(sampleID("b")) - require.True(t, found) - require.Equal(t, 123, v) - v, found = tree.Lookup(sampleID("c")) - require.False(t, found) - require.Nil(t, v) - v, found = tree.Lookup(sampleID("d")) - require.True(t, found) - require.Equal(t, 456, v) - } - verifyOrig() - - treeDump := tree.Dump() - tree1 := tree.Copy() - - // flagCloned on the root should be cleared after copy - // and not set again by Set b/c the value is the same - tree.Set(sampleID("d"), 456) // nothing changed - require.Zero(t, tree.(*syncTree).root.flags&flagCloned) - - tree1.Set(sampleID("b"), 1234) - tree1.Set(sampleID("c"), 222) - verifyOrig() - require.Equal(t, treeDump, tree.Dump()) - v, found := tree1.Lookup(sampleID("a")) - require.True(t, found) - require.Nil(t, v) - v, found = tree1.Lookup(sampleID("b")) - require.True(t, found) - require.Equal(t, 1234, v) - v, found = tree1.Lookup(sampleID("c")) - require.True(t, found) - require.Equal(t, 222, v) - v, found = tree1.Lookup(sampleID("d")) - require.True(t, found) - require.Equal(t, 456, v) -} - -func TestParallelAddition(t *testing.T) { - for i := 0; i < 10; i++ { - const ( - nInitial = 10000 - nAdd = 1000 - nSets = 100 - ) - srcTree := NewSyncTree(Hash32To12Xor{}) - initialHashes := make([]types.Hash32, nInitial) - for n := range initialHashes { - h := types.RandomHash() - initialHashes[n] = h - srcTree.Add(h) - } - type set struct { - added []types.Hash32 - tree SyncTree - } - sets := make([]*set, nSets) - for n := range sets { - sets[n] = &set{} - } - sets[0].tree = srcTree - var wg sync.WaitGroup - for n, s := range sets { - wg.Add(1) - go func() { - defer wg.Done() - if n > 0 { - s.tree = srcTree.Copy() - } - s.added = make([]types.Hash32, nAdd) - for n := range s.added { - h := types.RandomHash() - s.added[n] = h - s.tree.Add(h) - } - }() - } - wg.Wait() - for _, s := range 
sets { - items := make(map[types.Hash32]struct{}, nInitial+nAdd) - for ptr := s.tree.Min(); ptr.Key() != nil; ptr.Next() { - items[ptr.Key().(types.Hash32)] = struct{}{} - } - require.GreaterOrEqual(t, len(items), nInitial+nAdd) - for _, k := range s.added { - _, found := items[k] // faster than require.Contains - require.True(t, found) - } - } - } -} diff --git a/sync2/hashsync/wire_helpers.go b/sync2/hashsync/wire_helpers.go deleted file mode 100644 index 4229fc6046..0000000000 --- a/sync2/hashsync/wire_helpers.go +++ /dev/null @@ -1,70 +0,0 @@ -package hashsync - -import ( - "github.com/spacemeshos/go-scale" - "github.com/spacemeshos/go-spacemesh/common/types" -) - -type CompactHash32 struct { - H *types.Hash32 -} - -// DecodeScale implements scale.Decodable. -func (c *CompactHash32) DecodeScale(dec *scale.Decoder) (int, error) { - var h types.Hash32 - b, total, err := scale.DecodeByte(dec) - switch { - case err != nil: - return total, err - case b == 255: - c.H = nil - return total, nil - case b != 0: - n, err := scale.DecodeByteArray(dec, h[:b]) - total += n - if err != nil { - return total, err - } - } - c.H = &h - return total, nil -} - -// EncodeScale implements scale.Encodable. 
-func (c *CompactHash32) EncodeScale(enc *scale.Encoder) (int, error) { - if c.H == nil { - return scale.EncodeByte(enc, 255) - } - - b := byte(31) - for b = 32; b > 0; b-- { - if c.H[b-1] != 0 { - break - } - } - - total, err := scale.EncodeByte(enc, b) - if b == 0 || err != nil { - return total, err - } - - n, err := scale.EncodeByteArray(enc, c.H[:b]) - total += n - return total, err -} - -func (c *CompactHash32) ToOrdered() Ordered { - if c.H == nil { - return nil - } - return *c.H -} - -func Hash32ToCompact(h types.Hash32) CompactHash32 { - return CompactHash32{H: &h} -} - -func OrderedToCompactHash32(h Ordered) CompactHash32 { - hash := h.(types.Hash32) - return CompactHash32{H: &hash} -} diff --git a/sync2/hashsync/wire_types.go b/sync2/hashsync/wire_types.go deleted file mode 100644 index 824ec8fbb3..0000000000 --- a/sync2/hashsync/wire_types.go +++ /dev/null @@ -1,209 +0,0 @@ -package hashsync - -import ( - "cmp" - "fmt" - "time" - - "github.com/spacemeshos/go-scale" - "github.com/spacemeshos/go-spacemesh/common/types" -) - -//go:generate scalegen - -type Marker struct{} - -func (*Marker) X() Ordered { return nil } -func (*Marker) Y() Ordered { return nil } -func (*Marker) Fingerprint() any { return nil } -func (*Marker) Count() int { return 0 } -func (*Marker) Keys() []Ordered { return nil } -func (*Marker) Since() time.Time { return time.Time{} } - -// DoneMessage is a SyncMessage that denotes the end of the synchronization. -// The peer should stop any further processing after receiving this message. -type DoneMessage struct{ Marker } - -var _ SyncMessage = &DoneMessage{} - -func (*DoneMessage) Type() MessageType { return MessageTypeDone } - -// EndRoundMessage is a SyncMessage that denotes the end of the sync round. 
-type EndRoundMessage struct{ Marker } - -var _ SyncMessage = &EndRoundMessage{} - -func (*EndRoundMessage) Type() MessageType { return MessageTypeEndRound } - -// EmptySetMessage is a SyncMessage that denotes an empty set, requesting the -// peer to send all of its items -type EmptySetMessage struct{ Marker } - -var _ SyncMessage = &EmptySetMessage{} - -func (*EmptySetMessage) Type() MessageType { return MessageTypeEmptySet } - -// EmptyRangeMessage notifies the peer that it needs to send all of its items in -// the specified range -type EmptyRangeMessage struct { - RangeX, RangeY CompactHash32 -} - -var _ SyncMessage = &EmptyRangeMessage{} - -func (m *EmptyRangeMessage) Type() MessageType { return MessageTypeEmptyRange } -func (m *EmptyRangeMessage) X() Ordered { return m.RangeX.ToOrdered() } -func (m *EmptyRangeMessage) Y() Ordered { return m.RangeY.ToOrdered() } -func (m *EmptyRangeMessage) Fingerprint() any { return nil } -func (m *EmptyRangeMessage) Count() int { return 0 } -func (m *EmptyRangeMessage) Keys() []Ordered { return nil } -func (m *EmptyRangeMessage) Since() time.Time { return time.Time{} } - -// FingerprintMessage contains range fingerprint for comparison against the -// peer's fingerprint of the range with the same bounds [RangeX, RangeY) -type FingerprintMessage struct { - RangeX, RangeY CompactHash32 - RangeFingerprint types.Hash12 - NumItems uint32 -} - -var _ SyncMessage = &FingerprintMessage{} - -func (m *FingerprintMessage) Type() MessageType { return MessageTypeFingerprint } -func (m *FingerprintMessage) X() Ordered { return m.RangeX.ToOrdered() } -func (m *FingerprintMessage) Y() Ordered { return m.RangeY.ToOrdered() } -func (m *FingerprintMessage) Fingerprint() any { return m.RangeFingerprint } -func (m *FingerprintMessage) Count() int { return int(m.NumItems) } -func (m *FingerprintMessage) Keys() []Ordered { return nil } -func (m *FingerprintMessage) Since() time.Time { return time.Time{} } - -// RangeContentsMessage denotes a range 
for which the set of items has been sent. -// The peer needs to send back any items it has in the same range bounded -// by [RangeX, RangeY) -type RangeContentsMessage struct { - RangeX, RangeY CompactHash32 - NumItems uint32 -} - -var _ SyncMessage = &RangeContentsMessage{} - -func (m *RangeContentsMessage) Type() MessageType { return MessageTypeRangeContents } -func (m *RangeContentsMessage) X() Ordered { return m.RangeX.ToOrdered() } -func (m *RangeContentsMessage) Y() Ordered { return m.RangeY.ToOrdered() } -func (m *RangeContentsMessage) Fingerprint() any { return nil } -func (m *RangeContentsMessage) Count() int { return int(m.NumItems) } -func (m *RangeContentsMessage) Keys() []Ordered { return nil } -func (m *RangeContentsMessage) Since() time.Time { return time.Time{} } - -// ItemBatchMessage denotes a batch of items to be added to the peer's set. -type ItemBatchMessage struct { - ContentKeys []types.Hash32 `scale:"max=1024"` -} - -func (m *ItemBatchMessage) Type() MessageType { return MessageTypeItemBatch } -func (m *ItemBatchMessage) X() Ordered { return nil } -func (m *ItemBatchMessage) Y() Ordered { return nil } -func (m *ItemBatchMessage) Fingerprint() any { return nil } -func (m *ItemBatchMessage) Count() int { return 0 } -func (m *ItemBatchMessage) Keys() []Ordered { - var r []Ordered - for _, k := range m.ContentKeys { - r = append(r, k) - } - return r -} -func (m *ItemBatchMessage) Since() time.Time { return time.Time{} } - -// ProbeMessage requests bounded range fingerprint and count from the peer, -// along with a minhash sample if fingerprints differ -type ProbeMessage struct { - RangeX, RangeY CompactHash32 - RangeFingerprint types.Hash12 - SampleSize uint32 -} - -var _ SyncMessage = &ProbeMessage{} - -func (m *ProbeMessage) Type() MessageType { return MessageTypeProbe } -func (m *ProbeMessage) X() Ordered { return m.RangeX.ToOrdered() } -func (m *ProbeMessage) Y() Ordered { return m.RangeY.ToOrdered() } -func (m *ProbeMessage) Fingerprint() 
any { return m.RangeFingerprint } -func (m *ProbeMessage) Count() int { return int(m.SampleSize) } -func (m *ProbeMessage) Keys() []Ordered { return nil } -func (m *ProbeMessage) Since() time.Time { return time.Time{} } - -// MinhashSampleItem represents an item of minhash sample subset -type MinhashSampleItem uint32 - -var _ Ordered = MinhashSampleItem(0) - -func (m MinhashSampleItem) String() string { - return fmt.Sprintf("0x%08x", uint32(m)) -} - -// Compare implements Ordered -func (m MinhashSampleItem) Compare(other any) int { - return cmp.Compare(m, other.(MinhashSampleItem)) -} - -// EncodeScale implements scale.Encodable. -func (m MinhashSampleItem) EncodeScale(e *scale.Encoder) (int, error) { - // QQQQQ: FIXME: there's EncodeUint32 (non-compact which is better for hashes) - // but no DecodeUint32 - return scale.EncodeCompact32(e, uint32(m)) -} - -// DecodeScale implements scale.Decodable. -func (m *MinhashSampleItem) DecodeScale(d *scale.Decoder) (int, error) { - v, total, err := scale.DecodeCompact32(d) - *m = MinhashSampleItem(v) - return total, err -} - -// MinhashSampleItemFromHash32 uses lower 32 bits of a Hash32 as a MinhashSampleItem -func MinhashSampleItemFromHash32(h types.Hash32) MinhashSampleItem { - return MinhashSampleItem(uint32(h[28])<<24 + uint32(h[29])<<16 + uint32(h[30])<<8 + uint32(h[31])) -} - -// SampleMessage is a sample of set items -type SampleMessage struct { - RangeX, RangeY CompactHash32 - RangeFingerprint types.Hash12 - NumItems uint32 - // NOTE: max must be in sync with maxSampleSize in hashsync/rangesync.go - Sample []MinhashSampleItem `scale:"max=1000"` -} - -var _ SyncMessage = &SampleMessage{} - -func (m *SampleMessage) Type() MessageType { return MessageTypeSample } -func (m *SampleMessage) X() Ordered { return m.RangeX.ToOrdered() } -func (m *SampleMessage) Y() Ordered { return m.RangeY.ToOrdered() } -func (m *SampleMessage) Fingerprint() any { return m.RangeFingerprint } -func (m *SampleMessage) Count() int { return 
int(m.NumItems) } -func (m *SampleMessage) Keys() []Ordered { - r := make([]Ordered, len(m.Sample)) - for n, item := range m.Sample { - r[n] = item - } - return r -} -func (m *SampleMessage) Since() time.Time { return time.Time{} } - -// RecentMessage is a SyncMessage that denotes a set of items that have been -// added to the peer's set since the specific point in time. -type RecentMessage struct { - SinceTime uint64 -} - -var _ SyncMessage = &RecentMessage{} - -func (m *RecentMessage) Type() MessageType { return MessageTypeRecent } -func (m *RecentMessage) X() Ordered { return nil } -func (m *RecentMessage) Y() Ordered { return nil } -func (m *RecentMessage) Fingerprint() any { return nil } -func (m *RecentMessage) Count() int { return 0 } -func (m *RecentMessage) Keys() []Ordered { return nil } -func (m *RecentMessage) Since() time.Time { return time.Unix(0, int64(m.SinceTime)) } - -// TODO: don't do scalegen for empty types diff --git a/sync2/hashsync/xorsync.go b/sync2/hashsync/xorsync.go deleted file mode 100644 index e14773a21d..0000000000 --- a/sync2/hashsync/xorsync.go +++ /dev/null @@ -1,59 +0,0 @@ -package hashsync - -import ( - "sync" - - "github.com/zeebo/blake3" - - "github.com/spacemeshos/go-spacemesh/common/types" -) - -// Note: we don't care too much about artificially induced collisions. -// Given that none of the synced hashes are used internally or -// propagated further down the P2P network before the actual contents -// of the objects is received and validated, most an attacker can get -// is partial sync of this node with the attacker node, which doesn't -// pose any serious threat. We could even skip additional hashing -// altogether, but let's make playing the algorithm not too easy. 
- -type Hash32To12Xor struct{} - -var _ Monoid = Hash32To12Xor{} - -func (m Hash32To12Xor) Identity() any { - return types.Hash12{} -} - -func (m Hash32To12Xor) Op(b, a any) any { - var r types.Hash12 - h1 := a.(types.Hash12) - h2 := b.(types.Hash12) - for n, b := range h1 { - r[n] = b ^ h2[n] - } - return r -} - -var hashPool = &sync.Pool{ - New: func() any { - return blake3.New() - }, -} - -func (m Hash32To12Xor) Fingerprint(v any) any { - // Blake3's New allocates too much memory, - // so we can't just call types.CalcHash12(h[:]) here - // TODO: fix types.CalcHash12() - h := v.(types.Hash32) - var r types.Hash12 - hasher := hashPool.Get().(*blake3.Hasher) - defer func() { - hasher.Reset() - hashPool.Put(hasher) - }() - var hashRes [32]byte - hasher.Write(h[:]) - hasher.Sum(hashRes[:0]) - copy(r[:], hashRes[:]) - return r -} diff --git a/sync2/hashsync/xorsync_test.go b/sync2/hashsync/xorsync_test.go deleted file mode 100644 index 94dca0e53b..0000000000 --- a/sync2/hashsync/xorsync_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package hashsync - -import ( - "context" - "math/rand" - "slices" - "testing" - - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestHash32To12Xor(t *testing.T) { - var m Hash32To12Xor - require.Equal(t, m.Identity(), m.Op(m.Identity(), m.Identity())) - hash1 := types.CalcHash32([]byte("foo")) - fp1 := m.Fingerprint(hash1) - hash2 := types.CalcHash32([]byte("bar")) - fp2 := m.Fingerprint(hash2) - hash3 := types.CalcHash32([]byte("baz")) - fp3 := m.Fingerprint(hash3) - require.Equal(t, fp1, m.Op(m.Identity(), fp1)) - require.Equal(t, fp2, m.Op(fp2, m.Identity())) - require.NotEqual(t, fp1, fp2) - require.NotEqual(t, fp1, fp3) - require.NotEqual(t, fp1, m.Op(fp1, fp2)) - require.NotEqual(t, fp2, m.Op(fp1, fp2)) - require.NotEqual(t, m.Identity(), m.Op(fp1, fp2)) - require.Equal(t, m.Op(m.Op(fp1, fp2), fp3), m.Op(fp1, m.Op(fp2, fp3))) -} - -type 
catchTransferTwice struct { - ItemStore - t *testing.T - added map[types.Hash32]bool -} - -func (s *catchTransferTwice) Add(ctx context.Context, k Ordered) error { - h := k.(types.Hash32) - _, found := s.added[h] - assert.False(s.t, found, "hash sent twice") - if err := s.ItemStore.Add(ctx, k); err != nil { - return err - } - if s.added == nil { - s.added = make(map[types.Hash32]bool) - } - s.added[h] = true - return nil -} - -type xorSyncTestConfig struct { - maxSendRange int - numTestHashes int - minNumSpecificA int - maxNumSpecificA int - minNumSpecificB int - maxNumSpecificB int - allowReAdd bool -} - -func verifyXORSync(t *testing.T, cfg xorSyncTestConfig, sync func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool) { - opts := []RangeSetReconcilerOption{ - WithMaxSendRange(cfg.maxSendRange), - WithMaxDiff(0.1), - } - numSpecificA := rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA - numSpecificB := rand.Intn(cfg.maxNumSpecificB+1-cfg.minNumSpecificB) + cfg.minNumSpecificB - src := make([]types.Hash32, cfg.numTestHashes) - for n := range src { - src[n] = types.RandomHash() - } - - sliceA := src[:cfg.numTestHashes-numSpecificB] - storeA := NewSyncTreeStore(Hash32To12Xor{}) - for _, h := range sliceA { - require.NoError(t, storeA.Add(context.Background(), h)) - } - if !cfg.allowReAdd { - storeA = &catchTransferTwice{t: t, ItemStore: storeA} - } - - sliceB := append([]types.Hash32(nil), src[:cfg.numTestHashes-numSpecificB-numSpecificA]...) - sliceB = append(sliceB, src[cfg.numTestHashes-numSpecificB:]...) 
- storeB := NewSyncTreeStore(Hash32To12Xor{}) - for _, h := range sliceB { - require.NoError(t, storeB.Add(context.Background(), h)) - } - if !cfg.allowReAdd { - storeB = &catchTransferTwice{t: t, ItemStore: storeB} - } - - slices.SortFunc(src, func(a, b types.Hash32) int { - return a.Compare(b) - }) - - if sync(storeA, storeB, numSpecificA+numSpecificB, opts) { - itemsA, err := CollectStoreItems[types.Hash32](context.Background(), storeA) - require.NoError(t, err) - itemsB, err := CollectStoreItems[types.Hash32](context.Background(), storeB) - require.NoError(t, err) - require.Equal(t, itemsA, itemsB) - srcKeys := make([]types.Hash32, len(src)) - for n, h := range src { - srcKeys[n] = h - } - require.Equal(t, srcKeys, itemsA) - } -} - -func TestBigSyncHash32(t *testing.T) { - cfg := xorSyncTestConfig{ - maxSendRange: 1, - numTestHashes: 100000, - minNumSpecificA: 4, - maxNumSpecificA: 100, - minNumSpecificB: 4, - maxNumSpecificB: 100, - } - verifyXORSync(t, cfg, func(storeA, storeB ItemStore, numSpecific int, opts []RangeSetReconcilerOption) bool { - syncA := NewRangeSetReconciler(storeA, opts...) - syncB := NewRangeSetReconciler(storeB, opts...) 
- nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) - itemCoef := float64(nItems) / float64(numSpecific) - t.Logf("numSpecific: %d, nRounds: %d, nMsg: %d, nItems: %d, itemCoef: %.2f", - numSpecific, nRounds, nMsg, nItems, itemCoef) - return true - }) -} diff --git a/sync2/multipeer/delim.go b/sync2/multipeer/delim.go new file mode 100644 index 0000000000..7127e5b3e6 --- /dev/null +++ b/sync2/multipeer/delim.go @@ -0,0 +1,22 @@ +package multipeer + +import ( + "encoding/binary" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +func getDelimiters(numPeers, keyLen, maxDepth int) (h []types.KeyBytes) { + if numPeers < 2 { + return nil + } + mask := uint64(0xffffffffffffffff) << (64 - maxDepth) + inc := (uint64(0x80) << 56) / uint64(numPeers) + h = make([]types.KeyBytes, numPeers-1) + for i, v := 0, uint64(0); i < numPeers-1; i++ { + h[i] = make(types.KeyBytes, keyLen) + v += inc + binary.BigEndian.PutUint64(h[i], (v<<1)&mask) + } + return h +} diff --git a/sync2/multipeer/delim_test.go b/sync2/multipeer/delim_test.go new file mode 100644 index 0000000000..624ef78c65 --- /dev/null +++ b/sync2/multipeer/delim_test.go @@ -0,0 +1,103 @@ +package multipeer + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetDelimiters(t *testing.T) { + for _, tc := range []struct { + numPeers int + keyLen int + maxDepth int + values []string + }{ + { + numPeers: 0, + maxDepth: 64, + keyLen: 32, + values: nil, + }, + { + numPeers: 1, + maxDepth: 64, + keyLen: 32, + values: nil, + }, + { + numPeers: 2, + maxDepth: 64, + keyLen: 32, + values: []string{ + "8000000000000000000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 2, + maxDepth: 24, + keyLen: 32, + values: []string{ + "8000000000000000000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 3, + maxDepth: 64, + keyLen: 32, + values: []string{ + "5555555555555554000000000000000000000000000000000000000000000000", + 
"aaaaaaaaaaaaaaa8000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 3, + maxDepth: 24, + keyLen: 32, + values: []string{ + "5555550000000000000000000000000000000000000000000000000000000000", + "aaaaaa0000000000000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 3, + maxDepth: 4, + keyLen: 32, + values: []string{ + "5000000000000000000000000000000000000000000000000000000000000000", + "a000000000000000000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 4, + maxDepth: 64, + keyLen: 32, + values: []string{ + "4000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", + "c000000000000000000000000000000000000000000000000000000000000000", + }, + }, + { + numPeers: 4, + maxDepth: 24, + keyLen: 32, + values: []string{ + "4000000000000000000000000000000000000000000000000000000000000000", + "8000000000000000000000000000000000000000000000000000000000000000", + "c000000000000000000000000000000000000000000000000000000000000000", + }, + }, + } { + ds := getDelimiters(tc.numPeers, tc.keyLen, tc.maxDepth) + var hs []string + for _, d := range ds { + hs = append(hs, d.String()) + } + if len(tc.values) == 0 { + require.Empty(t, hs, "%d delimiters", tc.numPeers) + } else { + require.Equal(t, tc.values, hs, "%d delimiters", tc.numPeers) + } + } +} diff --git a/sync2/multipeer/interface.go b/sync2/multipeer/interface.go new file mode 100644 index 0000000000..57e201ca18 --- /dev/null +++ b/sync2/multipeer/interface.go @@ -0,0 +1,46 @@ +package multipeer + +import ( + "context" + "io" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +//go:generate mockgen -typed -package=multipeer -destination=./mocks_test.go -source=./interface.go + +type SyncBase interface { + Count(ctx context.Context) (int, error) + Derive(p p2p.Peer) Syncer 
+ Probe(ctx context.Context, p p2p.Peer) (rangesync.ProbeResult, error) + Wait() error +} + +type Syncer interface { + Peer() p2p.Peer + Sync(ctx context.Context, x, y types.KeyBytes) error + Serve(ctx context.Context, req []byte, stream io.ReadWriter) error +} + +type PairwiseSyncer interface { + Probe( + ctx context.Context, + peer p2p.Peer, + os rangesync.OrderedSet, + x, y types.KeyBytes, + ) (rangesync.ProbeResult, error) + Sync( + ctx context.Context, + peer p2p.Peer, + os rangesync.OrderedSet, + x, y types.KeyBytes, + ) error + Serve(ctx context.Context, req []byte, stream io.ReadWriter, os rangesync.OrderedSet) error +} + +type syncRunner interface { + splitSync(ctx context.Context, syncPeers []p2p.Peer) error + fullSync(ctx context.Context, syncPeers []p2p.Peer) error +} diff --git a/sync2/multipeer/mocks_test.go b/sync2/multipeer/mocks_test.go new file mode 100644 index 0000000000..2846331e0c --- /dev/null +++ b/sync2/multipeer/mocks_test.go @@ -0,0 +1,572 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -typed -package=multipeer -destination=./mocks_test.go -source=./interface.go +// + +// Package multipeer is a generated GoMock package. +package multipeer + +import ( + context "context" + io "io" + reflect "reflect" + + p2p "github.com/spacemeshos/go-spacemesh/p2p" + rangesync "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + types "github.com/spacemeshos/go-spacemesh/sync2/types" + gomock "go.uber.org/mock/gomock" +) + +// MockSyncBase is a mock of SyncBase interface. +type MockSyncBase struct { + ctrl *gomock.Controller + recorder *MockSyncBaseMockRecorder +} + +// MockSyncBaseMockRecorder is the mock recorder for MockSyncBase. +type MockSyncBaseMockRecorder struct { + mock *MockSyncBase +} + +// NewMockSyncBase creates a new mock instance. 
+func NewMockSyncBase(ctrl *gomock.Controller) *MockSyncBase { + mock := &MockSyncBase{ctrl: ctrl} + mock.recorder = &MockSyncBaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSyncBase) EXPECT() *MockSyncBaseMockRecorder { + return m.recorder +} + +// Count mocks base method. +func (m *MockSyncBase) Count(ctx context.Context) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Count", ctx) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Count indicates an expected call of Count. +func (mr *MockSyncBaseMockRecorder) Count(ctx any) *MockSyncBaseCountCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockSyncBase)(nil).Count), ctx) + return &MockSyncBaseCountCall{Call: call} +} + +// MockSyncBaseCountCall wrap *gomock.Call +type MockSyncBaseCountCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncBaseCountCall) Return(arg0 int, arg1 error) *MockSyncBaseCountCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncBaseCountCall) Do(f func(context.Context) (int, error)) *MockSyncBaseCountCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncBaseCountCall) DoAndReturn(f func(context.Context) (int, error)) *MockSyncBaseCountCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Derive mocks base method. +func (m *MockSyncBase) Derive(p p2p.Peer) Syncer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Derive", p) + ret0, _ := ret[0].(Syncer) + return ret0 +} + +// Derive indicates an expected call of Derive. 
+func (mr *MockSyncBaseMockRecorder) Derive(p any) *MockSyncBaseDeriveCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Derive", reflect.TypeOf((*MockSyncBase)(nil).Derive), p) + return &MockSyncBaseDeriveCall{Call: call} +} + +// MockSyncBaseDeriveCall wrap *gomock.Call +type MockSyncBaseDeriveCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncBaseDeriveCall) Return(arg0 Syncer) *MockSyncBaseDeriveCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncBaseDeriveCall) Do(f func(p2p.Peer) Syncer) *MockSyncBaseDeriveCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncBaseDeriveCall) DoAndReturn(f func(p2p.Peer) Syncer) *MockSyncBaseDeriveCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Probe mocks base method. +func (m *MockSyncBase) Probe(ctx context.Context, p p2p.Peer) (rangesync.ProbeResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Probe", ctx, p) + ret0, _ := ret[0].(rangesync.ProbeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Probe indicates an expected call of Probe. 
+func (mr *MockSyncBaseMockRecorder) Probe(ctx, p any) *MockSyncBaseProbeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockSyncBase)(nil).Probe), ctx, p) + return &MockSyncBaseProbeCall{Call: call} +} + +// MockSyncBaseProbeCall wrap *gomock.Call +type MockSyncBaseProbeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncBaseProbeCall) Return(arg0 rangesync.ProbeResult, arg1 error) *MockSyncBaseProbeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncBaseProbeCall) Do(f func(context.Context, p2p.Peer) (rangesync.ProbeResult, error)) *MockSyncBaseProbeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncBaseProbeCall) DoAndReturn(f func(context.Context, p2p.Peer) (rangesync.ProbeResult, error)) *MockSyncBaseProbeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Wait mocks base method. +func (m *MockSyncBase) Wait() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait") + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. 
+func (mr *MockSyncBaseMockRecorder) Wait() *MockSyncBaseWaitCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockSyncBase)(nil).Wait)) + return &MockSyncBaseWaitCall{Call: call} +} + +// MockSyncBaseWaitCall wrap *gomock.Call +type MockSyncBaseWaitCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncBaseWaitCall) Return(arg0 error) *MockSyncBaseWaitCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncBaseWaitCall) Do(f func() error) *MockSyncBaseWaitCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncBaseWaitCall) DoAndReturn(f func() error) *MockSyncBaseWaitCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockSyncer is a mock of Syncer interface. +type MockSyncer struct { + ctrl *gomock.Controller + recorder *MockSyncerMockRecorder +} + +// MockSyncerMockRecorder is the mock recorder for MockSyncer. +type MockSyncerMockRecorder struct { + mock *MockSyncer +} + +// NewMockSyncer creates a new mock instance. +func NewMockSyncer(ctrl *gomock.Controller) *MockSyncer { + mock := &MockSyncer{ctrl: ctrl} + mock.recorder = &MockSyncerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSyncer) EXPECT() *MockSyncerMockRecorder { + return m.recorder +} + +// Peer mocks base method. +func (m *MockSyncer) Peer() p2p.Peer { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Peer") + ret0, _ := ret[0].(p2p.Peer) + return ret0 +} + +// Peer indicates an expected call of Peer. 
+func (mr *MockSyncerMockRecorder) Peer() *MockSyncerPeerCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Peer", reflect.TypeOf((*MockSyncer)(nil).Peer)) + return &MockSyncerPeerCall{Call: call} +} + +// MockSyncerPeerCall wrap *gomock.Call +type MockSyncerPeerCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncerPeerCall) Return(arg0 p2p.Peer) *MockSyncerPeerCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncerPeerCall) Do(f func() p2p.Peer) *MockSyncerPeerCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncerPeerCall) DoAndReturn(f func() p2p.Peer) *MockSyncerPeerCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Serve mocks base method. +func (m *MockSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Serve", ctx, req, stream) + ret0, _ := ret[0].(error) + return ret0 +} + +// Serve indicates an expected call of Serve. 
+func (mr *MockSyncerMockRecorder) Serve(ctx, req, stream any) *MockSyncerServeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockSyncer)(nil).Serve), ctx, req, stream) + return &MockSyncerServeCall{Call: call} +} + +// MockSyncerServeCall wrap *gomock.Call +type MockSyncerServeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncerServeCall) Return(arg0 error) *MockSyncerServeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncerServeCall) Do(f func(context.Context, []byte, io.ReadWriter) error) *MockSyncerServeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncerServeCall) DoAndReturn(f func(context.Context, []byte, io.ReadWriter) error) *MockSyncerServeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Sync mocks base method. +func (m *MockSyncer) Sync(ctx context.Context, x, y types.KeyBytes) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sync", ctx, x, y) + ret0, _ := ret[0].(error) + return ret0 +} + +// Sync indicates an expected call of Sync. 
+func (mr *MockSyncerMockRecorder) Sync(ctx, x, y any) *MockSyncerSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockSyncer)(nil).Sync), ctx, x, y) + return &MockSyncerSyncCall{Call: call} +} + +// MockSyncerSyncCall wrap *gomock.Call +type MockSyncerSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSyncerSyncCall) Return(arg0 error) *MockSyncerSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSyncerSyncCall) Do(f func(context.Context, types.KeyBytes, types.KeyBytes) error) *MockSyncerSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSyncerSyncCall) DoAndReturn(f func(context.Context, types.KeyBytes, types.KeyBytes) error) *MockSyncerSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockPairwiseSyncer is a mock of PairwiseSyncer interface. +type MockPairwiseSyncer struct { + ctrl *gomock.Controller + recorder *MockPairwiseSyncerMockRecorder +} + +// MockPairwiseSyncerMockRecorder is the mock recorder for MockPairwiseSyncer. +type MockPairwiseSyncerMockRecorder struct { + mock *MockPairwiseSyncer +} + +// NewMockPairwiseSyncer creates a new mock instance. +func NewMockPairwiseSyncer(ctrl *gomock.Controller) *MockPairwiseSyncer { + mock := &MockPairwiseSyncer{ctrl: ctrl} + mock.recorder = &MockPairwiseSyncerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPairwiseSyncer) EXPECT() *MockPairwiseSyncerMockRecorder { + return m.recorder +} + +// Probe mocks base method. 
+func (m *MockPairwiseSyncer) Probe(ctx context.Context, peer p2p.Peer, os rangesync.OrderedSet, x, y types.KeyBytes) (rangesync.ProbeResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Probe", ctx, peer, os, x, y) + ret0, _ := ret[0].(rangesync.ProbeResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Probe indicates an expected call of Probe. +func (mr *MockPairwiseSyncerMockRecorder) Probe(ctx, peer, os, x, y any) *MockPairwiseSyncerProbeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Probe", reflect.TypeOf((*MockPairwiseSyncer)(nil).Probe), ctx, peer, os, x, y) + return &MockPairwiseSyncerProbeCall{Call: call} +} + +// MockPairwiseSyncerProbeCall wrap *gomock.Call +type MockPairwiseSyncerProbeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPairwiseSyncerProbeCall) Return(arg0 rangesync.ProbeResult, arg1 error) *MockPairwiseSyncerProbeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPairwiseSyncerProbeCall) Do(f func(context.Context, p2p.Peer, rangesync.OrderedSet, types.KeyBytes, types.KeyBytes) (rangesync.ProbeResult, error)) *MockPairwiseSyncerProbeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPairwiseSyncerProbeCall) DoAndReturn(f func(context.Context, p2p.Peer, rangesync.OrderedSet, types.KeyBytes, types.KeyBytes) (rangesync.ProbeResult, error)) *MockPairwiseSyncerProbeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Serve mocks base method. +func (m *MockPairwiseSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter, os rangesync.OrderedSet) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Serve", ctx, req, stream, os) + ret0, _ := ret[0].(error) + return ret0 +} + +// Serve indicates an expected call of Serve. 
+func (mr *MockPairwiseSyncerMockRecorder) Serve(ctx, req, stream, os any) *MockPairwiseSyncerServeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Serve", reflect.TypeOf((*MockPairwiseSyncer)(nil).Serve), ctx, req, stream, os) + return &MockPairwiseSyncerServeCall{Call: call} +} + +// MockPairwiseSyncerServeCall wrap *gomock.Call +type MockPairwiseSyncerServeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPairwiseSyncerServeCall) Return(arg0 error) *MockPairwiseSyncerServeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPairwiseSyncerServeCall) Do(f func(context.Context, []byte, io.ReadWriter, rangesync.OrderedSet) error) *MockPairwiseSyncerServeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPairwiseSyncerServeCall) DoAndReturn(f func(context.Context, []byte, io.ReadWriter, rangesync.OrderedSet) error) *MockPairwiseSyncerServeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Sync mocks base method. +func (m *MockPairwiseSyncer) Sync(ctx context.Context, peer p2p.Peer, os rangesync.OrderedSet, x, y types.KeyBytes) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sync", ctx, peer, os, x, y) + ret0, _ := ret[0].(error) + return ret0 +} + +// Sync indicates an expected call of Sync. 
+func (mr *MockPairwiseSyncerMockRecorder) Sync(ctx, peer, os, x, y any) *MockPairwiseSyncerSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockPairwiseSyncer)(nil).Sync), ctx, peer, os, x, y) + return &MockPairwiseSyncerSyncCall{Call: call} +} + +// MockPairwiseSyncerSyncCall wrap *gomock.Call +type MockPairwiseSyncerSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPairwiseSyncerSyncCall) Return(arg0 error) *MockPairwiseSyncerSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPairwiseSyncerSyncCall) Do(f func(context.Context, p2p.Peer, rangesync.OrderedSet, types.KeyBytes, types.KeyBytes) error) *MockPairwiseSyncerSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPairwiseSyncerSyncCall) DoAndReturn(f func(context.Context, p2p.Peer, rangesync.OrderedSet, types.KeyBytes, types.KeyBytes) error) *MockPairwiseSyncerSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MocksyncRunner is a mock of syncRunner interface. +type MocksyncRunner struct { + ctrl *gomock.Controller + recorder *MocksyncRunnerMockRecorder +} + +// MocksyncRunnerMockRecorder is the mock recorder for MocksyncRunner. +type MocksyncRunnerMockRecorder struct { + mock *MocksyncRunner +} + +// NewMocksyncRunner creates a new mock instance. +func NewMocksyncRunner(ctrl *gomock.Controller) *MocksyncRunner { + mock := &MocksyncRunner{ctrl: ctrl} + mock.recorder = &MocksyncRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MocksyncRunner) EXPECT() *MocksyncRunnerMockRecorder { + return m.recorder +} + +// fullSync mocks base method. 
+func (m *MocksyncRunner) fullSync(ctx context.Context, syncPeers []p2p.Peer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "fullSync", ctx, syncPeers) + ret0, _ := ret[0].(error) + return ret0 +} + +// fullSync indicates an expected call of fullSync. +func (mr *MocksyncRunnerMockRecorder) fullSync(ctx, syncPeers any) *MocksyncRunnerfullSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "fullSync", reflect.TypeOf((*MocksyncRunner)(nil).fullSync), ctx, syncPeers) + return &MocksyncRunnerfullSyncCall{Call: call} +} + +// MocksyncRunnerfullSyncCall wrap *gomock.Call +type MocksyncRunnerfullSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MocksyncRunnerfullSyncCall) Return(arg0 error) *MocksyncRunnerfullSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MocksyncRunnerfullSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MocksyncRunnerfullSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnerfullSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// splitSync mocks base method. +func (m *MocksyncRunner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "splitSync", ctx, syncPeers) + ret0, _ := ret[0].(error) + return ret0 +} + +// splitSync indicates an expected call of splitSync. 
+func (mr *MocksyncRunnerMockRecorder) splitSync(ctx, syncPeers any) *MocksyncRunnersplitSyncCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "splitSync", reflect.TypeOf((*MocksyncRunner)(nil).splitSync), ctx, syncPeers) + return &MocksyncRunnersplitSyncCall{Call: call} +} + +// MocksyncRunnersplitSyncCall wrap *gomock.Call +type MocksyncRunnersplitSyncCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MocksyncRunnersplitSyncCall) Return(arg0 error) *MocksyncRunnersplitSyncCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MocksyncRunnersplitSyncCall) Do(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MocksyncRunnersplitSyncCall) DoAndReturn(f func(context.Context, []p2p.Peer) error) *MocksyncRunnersplitSyncCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/sync2/hashsync/multipeer.go b/sync2/multipeer/multipeer.go similarity index 97% rename from sync2/hashsync/multipeer.go rename to sync2/multipeer/multipeer.go index 195c94d0c5..9df538670b 100644 --- a/sync2/hashsync/multipeer.go +++ b/sync2/multipeer/multipeer.go @@ -1,4 +1,4 @@ -package hashsync +package multipeer import ( "context" @@ -99,7 +99,7 @@ var _ syncRunner = &runner{} func (r *runner) splitSync(ctx context.Context, syncPeers []p2p.Peer) error { s := newSplitSync( r.mpr.logger, r.mpr.syncBase, r.mpr.peers, syncPeers, - r.mpr.splitSyncGracePeriod, r.mpr.clock) + r.mpr.splitSyncGracePeriod, r.mpr.clock, r.mpr.keyLen, r.mpr.maxDepth) return s.sync(ctx) } @@ -120,12 +120,15 @@ type MultiPeerReconciler struct { syncInterval time.Duration noPeersRecheckInterval time.Duration clock clockwork.Clock + keyLen int + maxDepth int runner syncRunner } func NewMultiPeerReconciler( syncBase SyncBase, peers *peers.Peers, + keyLen, maxDepth int, opts 
...MultiPeerReconcilerOpt, ) *MultiPeerReconciler { mpr := &MultiPeerReconciler{ @@ -141,6 +144,8 @@ func NewMultiPeerReconciler( splitSyncGracePeriod: time.Minute, noPeersRecheckInterval: 30 * time.Second, clock: clockwork.NewRealClock(), + keyLen: keyLen, + maxDepth: maxDepth, } for _, opt := range opts { opt(mpr) diff --git a/sync2/hashsync/multipeer_test.go b/sync2/multipeer/multipeer_test.go similarity index 90% rename from sync2/hashsync/multipeer_test.go rename to sync2/multipeer/multipeer_test.go index 58fd2ea1fe..235c1a9fc6 100644 --- a/sync2/hashsync/multipeer_test.go +++ b/sync2/multipeer/multipeer_test.go @@ -1,4 +1,4 @@ -package hashsync +package multipeer import ( "context" @@ -7,14 +7,15 @@ import ( "testing" "time" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" - "github.com/jonboulle/clockwork" "github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/stretchr/testify/require" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" ) // FIXME: BlockUntilContext is not included in FakeClock interface. 
@@ -49,7 +50,7 @@ func newMultiPeerSyncTester(t *testing.T) *multiPeerSyncTester { peers: peers.New(), clock: clockwork.NewFakeClock().(fakeClock), } - mt.reconciler = NewMultiPeerReconciler(mt.syncBase, mt.peers, + mt.reconciler = NewMultiPeerReconciler(mt.syncBase, mt.peers, 32, 24, WithLogger(zaptest.NewLogger(t)), WithSyncInterval(time.Minute), WithSyncPeerCount(6), @@ -82,10 +83,10 @@ func (mt *multiPeerSyncTester) start() context.Context { return ctx } -func (mt *multiPeerSyncTester) expectProbe(times int, pr ProbeResult) { +func (mt *multiPeerSyncTester) expectProbe(times int, pr rangesync.ProbeResult) { mt.selectedPeers = nil mt.syncBase.EXPECT().Probe(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, p p2p.Peer) (ProbeResult, error) { + func(_ context.Context, p p2p.Peer) (rangesync.ProbeResult, error) { require.NotContains(mt, mt.selectedPeers, p, "peer probed twice") require.True(mt, mt.peers.Contains(p)) mt.selectedPeers = append(mt.selectedPeers, p) @@ -133,7 +134,7 @@ func TestMultiPeerSync(t *testing.T) { // randomly and probed mt.syncBase.EXPECT().Count(gomock.Any()).Return(50, nil).AnyTimes() for i := 0; i < numSyncs; i++ { - mt.expectProbe(6, ProbeResult{ + mt.expectProbe(6, rangesync.ProbeResult{ FP: "foo", Count: 100, Sim: 0.5, // too low for full sync @@ -161,7 +162,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(10) mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { - mt.expectProbe(6, ProbeResult{ + mt.expectProbe(6, rangesync.ProbeResult{ FP: "foo", Count: 100, Sim: 0.99, // high enough for full sync @@ -181,7 +182,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(1) mt.syncBase.EXPECT().Count(gomock.Any()).Return(50, nil).AnyTimes() for i := 0; i < numSyncs; i++ { - mt.expectProbe(1, ProbeResult{ + mt.expectProbe(1, rangesync.ProbeResult{ FP: "foo", Count: 100, Sim: 0.5, // too low for full sync, but will have it anyway @@ -201,8 +202,8 @@ func 
TestMultiPeerSync(t *testing.T) { mt.addPeers(10) mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() mt.syncBase.EXPECT().Probe(gomock.Any(), gomock.Any()). - Return(ProbeResult{}, errors.New("probe failed")) - mt.expectProbe(5, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + Return(rangesync.ProbeResult{}, errors.New("probe failed")) + mt.expectProbe(5, rangesync.ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) // just 5 peers for which the probe worked will be checked mt.expectFullSync(5, 0) mt.syncBase.EXPECT().Wait().Times(2) @@ -216,7 +217,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(10) mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { - mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + mt.expectProbe(6, rangesync.ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 3) mt.syncBase.EXPECT().Wait() mt.clock.BlockUntilContext(ctx, 1) @@ -232,7 +233,7 @@ func TestMultiPeerSync(t *testing.T) { mt.addPeers(10) mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() for i := 0; i < numSyncs; i++ { - mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + mt.expectProbe(6, rangesync.ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.expectFullSync(6, 0) mt.syncBase.EXPECT().Wait().Return(errors.New("some handlers failed")) mt.clock.BlockUntilContext(ctx, 1) @@ -247,7 +248,7 @@ func TestMultiPeerSync(t *testing.T) { ctx := mt.start() mt.addPeers(10) mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() - mt.expectProbe(6, ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) + mt.expectProbe(6, rangesync.ProbeResult{FP: "foo", Count: 100, Sim: 0.99}) mt.syncRunner.EXPECT().fullSync(gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, peers []p2p.Peer) error { mt.cancel() diff --git a/sync2/hashsync/setsyncbase.go b/sync2/multipeer/setsyncbase.go similarity index 63% rename from sync2/hashsync/setsyncbase.go rename 
to sync2/multipeer/setsyncbase.go index 4c0dcbd57c..64f56ddea7 100644 --- a/sync2/hashsync/setsyncbase.go +++ b/sync2/multipeer/setsyncbase.go @@ -1,4 +1,4 @@ -package hashsync +package multipeer import ( "context" @@ -10,15 +10,16 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "golang.org/x/sync/singleflight" - "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) -type SyncKeyHandler func(ctx context.Context, k Ordered, peer p2p.Peer) error +type SyncKeyHandler func(ctx context.Context, k types.Ordered, peer p2p.Peer) error type SetSyncBase struct { sync.Mutex ps PairwiseSyncer - is ItemStore + os rangesync.OrderedSet handler SyncKeyHandler waiting []<-chan singleflight.Result g singleflight.Group @@ -26,10 +27,10 @@ type SetSyncBase struct { var _ SyncBase = &SetSyncBase{} -func NewSetSyncBase(ps PairwiseSyncer, is ItemStore, handler SyncKeyHandler) *SetSyncBase { +func NewSetSyncBase(ps PairwiseSyncer, os rangesync.OrderedSet, handler SyncKeyHandler) *SetSyncBase { return &SetSyncBase{ ps: ps, - is: is, + os: os, handler: handler, } } @@ -39,15 +40,20 @@ func (ssb *SetSyncBase) Count(ctx context.Context) (int, error) { // TODO: don't lock on db-bound operations ssb.Lock() defer ssb.Unlock() - it, err := ssb.is.Min(ctx) - if it == nil || err != nil { - return 0, err + if empty, err := ssb.os.Empty(ctx); err != nil { + return 0, fmt.Errorf("check if the set is empty: %w", err) + } else if empty { + return 0, nil } - x, err := it.Key() + seq, err := ssb.os.Items(ctx) if err != nil { - return 0, err + return 0, fmt.Errorf("get items: %w", err) + } + x, err := seq.First() + if err != nil { + return 0, fmt.Errorf("get first item: %w", err) } - info, err := ssb.is.GetRangeInfo(ctx, nil, x, x, -1) + info, err := ssb.os.GetRangeInfo(ctx, x, x, -1) if err != nil { return 0, err } @@ -60,26 +66,26 @@ func (ssb *SetSyncBase) Derive(p p2p.Peer) Syncer { defer 
ssb.Unlock() return &setSyncer{ SetSyncBase: ssb, - ItemStore: ssb.is.Copy(), + OrderedSet: ssb.os.Copy(), p: p, } } // Probe implements syncBase. -func (ssb *SetSyncBase) Probe(ctx context.Context, p p2p.Peer) (ProbeResult, error) { +func (ssb *SetSyncBase) Probe(ctx context.Context, p p2p.Peer) (rangesync.ProbeResult, error) { // Use a snapshot of the store to avoid holding the mutex for a long time ssb.Lock() - is := ssb.is.Copy() + os := ssb.os.Copy() ssb.Unlock() - return ssb.ps.Probe(ctx, p, is, nil, nil) + return ssb.ps.Probe(ctx, p, os, nil, nil) } -func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) error { +func (ssb *SetSyncBase) acceptKey(ctx context.Context, k types.Ordered, p p2p.Peer) error { ssb.Lock() defer ssb.Unlock() key := k.(fmt.Stringer).String() - has, err := ssb.is.Has(ctx, k) + has, err := ssb.os.Has(ctx, k) if err != nil { return err } @@ -90,7 +96,7 @@ func (ssb *SetSyncBase) acceptKey(ctx context.Context, k Ordered, p p2p.Peer) er if err == nil { ssb.Lock() defer ssb.Unlock() - err = ssb.is.Add(ctx, k) + err = ssb.os.Add(ctx, k) } return key, err })) @@ -118,13 +124,13 @@ func (ssb *SetSyncBase) Wait() error { type setSyncer struct { *SetSyncBase - ItemStore + rangesync.OrderedSet p p2p.Peer } var ( - _ Syncer = &setSyncer{} - _ ItemStore = &setSyncer{} + _ Syncer = &setSyncer{} + _ rangesync.OrderedSet = &setSyncer{} ) // Peer implements syncer. @@ -133,8 +139,8 @@ func (ss *setSyncer) Peer() p2p.Peer { } // Sync implements syncer. -func (ss *setSyncer) Sync(ctx context.Context, x, y *types.Hash32) error { - return ss.ps.SyncStore(ctx, ss.p, ss, x, y) +func (ss *setSyncer) Sync(ctx context.Context, x, y types.KeyBytes) error { + return ss.ps.Sync(ctx, ss.p, ss, x, y) } // Serve implements Syncer @@ -143,9 +149,9 @@ func (ss *setSyncer) Serve(ctx context.Context, req []byte, stream io.ReadWriter } // Add implements ItemStore. 
-func (ss *setSyncer) Add(ctx context.Context, k Ordered) error { +func (ss *setSyncer) Add(ctx context.Context, k types.Ordered) error { if err := ss.acceptKey(ctx, k, ss.p); err != nil { return err } - return ss.ItemStore.Add(ctx, k) + return ss.OrderedSet.Add(ctx, k) } diff --git a/sync2/hashsync/setsyncbase_test.go b/sync2/multipeer/setsyncbase_test.go similarity index 58% rename from sync2/hashsync/setsyncbase_test.go rename to sync2/multipeer/setsyncbase_test.go index 0d9e963d62..4c32e45764 100644 --- a/sync2/hashsync/setsyncbase_test.go +++ b/sync2/multipeer/setsyncbase_test.go @@ -1,4 +1,4 @@ -package hashsync +package multipeer import ( "context" @@ -6,59 +6,62 @@ import ( "sync" "testing" - "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/p2p" "github.com/stretchr/testify/require" gomock "go.uber.org/mock/gomock" "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + rmocks "github.com/spacemeshos/go-spacemesh/sync2/rangesync/mocks" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) type setSyncBaseTester struct { *testing.T ctrl *gomock.Controller ps *MockPairwiseSyncer - is *MockItemStore + os *rmocks.MockOrderedSet ssb *SetSyncBase waitMtx sync.Mutex - waitChs map[Ordered]chan error - doneCh chan Ordered + waitChs map[string]chan error + doneCh chan types.Ordered } -func newSetSyncBaseTester(t *testing.T, is ItemStore) *setSyncBaseTester { +func newSetSyncBaseTester(t *testing.T, os rangesync.OrderedSet) *setSyncBaseTester { ctrl := gomock.NewController(t) st := &setSyncBaseTester{ T: t, ctrl: ctrl, ps: NewMockPairwiseSyncer(ctrl), - waitChs: make(map[Ordered]chan error), - doneCh: make(chan Ordered), + waitChs: make(map[string]chan error), + doneCh: make(chan types.Ordered), } - if is == nil { - st.is = NewMockItemStore(ctrl) - is = st.is + if os == nil { + st.os = rmocks.NewMockOrderedSet(ctrl) + os = st.os } - st.ssb = 
NewSetSyncBase(st.ps, is, func(ctx context.Context, k Ordered, p p2p.Peer) error { - err := <-st.getWaitCh(k) + st.ssb = NewSetSyncBase(st.ps, os, func(ctx context.Context, k types.Ordered, p p2p.Peer) error { + err := <-st.getWaitCh(k.(types.KeyBytes)) st.doneCh <- k return err }) return st } -func (st *setSyncBaseTester) getWaitCh(k Ordered) chan error { +func (st *setSyncBaseTester) getWaitCh(k types.KeyBytes) chan error { st.waitMtx.Lock() defer st.waitMtx.Unlock() - ch, found := st.waitChs[k] + ch, found := st.waitChs[string(k)] if !found { ch = make(chan error) - st.waitChs[k] = ch + st.waitChs[string(k)] = ch } return ch } -func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types.Hash32) *MockItemStore { - copy := NewMockItemStore(st.ctrl) - st.is.EXPECT().Copy().DoAndReturn(func() ItemStore { +func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types.KeyBytes) *rmocks.MockOrderedSet { + copy := rmocks.NewMockOrderedSet(st.ctrl) + st.os.EXPECT().Copy().DoAndReturn(func() rangesync.OrderedSet { for _, k := range addedKeys { copy.EXPECT().Add(ctx, k) } @@ -71,12 +74,17 @@ func (st *setSyncBaseTester) expectSyncStore( ctx context.Context, p p2p.Peer, ss Syncer, - addedKeys ...types.Hash32, + addedKeys ...types.KeyBytes, ) { st.ps.EXPECT().SyncStore(ctx, p, ss, nil, nil). - DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { + DoAndReturn(func( + ctx context.Context, + p p2p.Peer, + os rangesync.OrderedSet, + x, y types.KeyBytes, + ) error { for _, k := range addedKeys { - require.NoError(st, is.Add(ctx, k)) + require.NoError(st, os.Add(ctx, k)) } return nil }) @@ -89,17 +97,17 @@ func (st *setSyncBaseTester) failToSyncStore( err error, ) { st.ps.EXPECT().SyncStore(ctx, p, ss, nil, nil). 
- DoAndReturn(func(ctx context.Context, p p2p.Peer, is ItemStore, x, y *types.Hash32) error { + DoAndReturn(func(ctx context.Context, p p2p.Peer, os rangesync.OrderedSet, x, y types.KeyBytes) error { return err }) } -func (st *setSyncBaseTester) wait(count int) ([]types.Hash32, error) { +func (st *setSyncBaseTester) wait(count int) ([]types.KeyBytes, error) { var eg errgroup.Group eg.Go(st.ssb.Wait) - var handledKeys []types.Hash32 + var handledKeys []types.KeyBytes for k := range st.doneCh { - handledKeys = append(handledKeys, k.(types.Hash32)) + handledKeys = append(handledKeys, k.(types.KeyBytes).Clone()) count-- if count == 0 { break @@ -113,8 +121,8 @@ func TestSetSyncBase(t *testing.T) { t.Parallel() st := newSetSyncBaseTester(t, nil) ctx := context.Background() - expPr := ProbeResult{ - FP: types.RandomHash(), + expPr := rangesync.ProbeResult{ + FP: types.RandomFingerprint(), Count: 42, Sim: 0.99, } @@ -130,25 +138,25 @@ func TestSetSyncBase(t *testing.T) { st := newSetSyncBaseTester(t, nil) ctx := context.Background() - addedKey := types.RandomHash() + addedKey := types.RandomKeyBytes(32) st.expectCopy(ctx, addedKey) ss := st.ssb.Derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.Peer()) - x := types.RandomHash() - y := types.RandomHash() - st.ps.EXPECT().SyncStore(ctx, p2p.Peer("p1"), ss, &x, &y) - require.NoError(t, ss.Sync(ctx, &x, &y)) + x := types.RandomKeyBytes(32) + y := types.RandomKeyBytes(32) + st.ps.EXPECT().SyncStore(ctx, p2p.Peer("p1"), ss, x, y) + require.NoError(t, ss.Sync(ctx, x, y)) - st.is.EXPECT().Has(gomock.Any(), addedKey) - st.is.EXPECT().Add(ctx, addedKey) + st.os.EXPECT().Has(gomock.Any(), addedKey) + st.os.EXPECT().Add(ctx, addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.Sync(ctx, nil, nil)) close(st.getWaitCh(addedKey)) handledKeys, err := st.wait(1) require.NoError(t, err) - require.ElementsMatch(t, []types.Hash32{addedKey}, handledKeys) + require.ElementsMatch(t, 
[]types.KeyBytes{addedKey}, handledKeys) }) t.Run("single key synced multiple times", func(t *testing.T) { @@ -156,15 +164,15 @@ func TestSetSyncBase(t *testing.T) { st := newSetSyncBaseTester(t, nil) ctx := context.Background() - addedKey := types.RandomHash() + addedKey := types.RandomKeyBytes(32) st.expectCopy(ctx, addedKey, addedKey, addedKey) ss := st.ssb.Derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.Peer()) // added just once - st.is.EXPECT().Add(ctx, addedKey) + st.os.EXPECT().Add(ctx, addedKey) for i := 0; i < 3; i++ { - st.is.EXPECT().Has(gomock.Any(), addedKey) + st.os.EXPECT().Has(gomock.Any(), addedKey) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.Sync(ctx, nil, nil)) } @@ -172,7 +180,7 @@ func TestSetSyncBase(t *testing.T) { handledKeys, err := st.wait(1) require.NoError(t, err) - require.ElementsMatch(t, []types.Hash32{addedKey}, handledKeys) + require.ElementsMatch(t, []types.KeyBytes{addedKey}, handledKeys) }) t.Run("multiple keys", func(t *testing.T) { @@ -180,16 +188,16 @@ func TestSetSyncBase(t *testing.T) { st := newSetSyncBaseTester(t, nil) ctx := context.Background() - k1 := types.RandomHash() - k2 := types.RandomHash() + k1 := types.RandomKeyBytes(32) + k2 := types.RandomKeyBytes(32) st.expectCopy(ctx, k1, k2) ss := st.ssb.Derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.Peer()) - st.is.EXPECT().Has(gomock.Any(), k1) - st.is.EXPECT().Has(gomock.Any(), k2) - st.is.EXPECT().Add(ctx, k1) - st.is.EXPECT().Add(ctx, k2) + st.os.EXPECT().Has(gomock.Any(), k1) + st.os.EXPECT().Has(gomock.Any(), k2) + st.os.EXPECT().Add(ctx, k1) + st.os.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) require.NoError(t, ss.Sync(ctx, nil, nil)) close(st.getWaitCh(k1)) @@ -197,7 +205,7 @@ func TestSetSyncBase(t *testing.T) { handledKeys, err := st.wait(2) require.NoError(t, err) - require.ElementsMatch(t, []types.Hash32{k1, k2}, handledKeys) + require.ElementsMatch(t, []types.KeyBytes{k1, 
k2}, handledKeys) }) t.Run("handler failure", func(t *testing.T) { @@ -205,16 +213,16 @@ func TestSetSyncBase(t *testing.T) { st := newSetSyncBaseTester(t, nil) ctx := context.Background() - k1 := types.RandomHash() - k2 := types.RandomHash() + k1 := types.RandomKeyBytes(32) + k2 := types.RandomKeyBytes(32) st.expectCopy(ctx, k1, k2) ss := st.ssb.Derive(p2p.Peer("p1")) require.Equal(t, p2p.Peer("p1"), ss.Peer()) - st.is.EXPECT().Has(gomock.Any(), k1) - st.is.EXPECT().Has(gomock.Any(), k2) + st.os.EXPECT().Has(gomock.Any(), k1) + st.os.EXPECT().Has(gomock.Any(), k2) // k1 is not propagated to syncBase due to the handler failure - st.is.EXPECT().Add(ctx, k2) + st.os.EXPECT().Add(ctx, k2) st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) require.NoError(t, ss.Sync(ctx, nil, nil)) handlerErr := errors.New("fail") @@ -223,27 +231,27 @@ func TestSetSyncBase(t *testing.T) { handledKeys, err := st.wait(2) require.ErrorIs(t, err, handlerErr) - require.ElementsMatch(t, []types.Hash32{k1, k2}, handledKeys) + require.ElementsMatch(t, []types.KeyBytes{k1, k2}, handledKeys) }) t.Run("synctree based item store", func(t *testing.T) { t.Parallel() - hs := make([]types.Hash32, 4) + hs := make([]types.KeyBytes, 4) for n := range hs { - hs[n] = types.RandomHash() + hs[n] = types.RandomKeyBytes(32) } - is := NewSyncTreeStore(Hash32To12Xor{}) - is.Add(context.Background(), hs[0]) - is.Add(context.Background(), hs[1]) - st := newSetSyncBaseTester(t, is) + os := rangesync.NewDumbHashSet(true) + os.Add(context.Background(), hs[0]) + os.Add(context.Background(), hs[1]) + st := newSetSyncBaseTester(t, os) ss := st.ssb.Derive(p2p.Peer("p1")) - ss.(ItemStore).Add(context.Background(), hs[2]) - ss.(ItemStore).Add(context.Background(), hs[3]) + ss.(rangesync.OrderedSet).Add(context.Background(), hs[2]) + ss.(rangesync.OrderedSet).Add(context.Background(), hs[3]) // syncer's cloned ItemStore has new key immediately - has, err := ss.(ItemStore).Has(context.Background(), hs[2]) + has, err := 
ss.(rangesync.OrderedSet).Has(context.Background(), hs[2]) require.NoError(t, err) require.True(t, has) - has, err = ss.(ItemStore).Has(context.Background(), hs[3]) + has, err = ss.(rangesync.OrderedSet).Has(context.Background(), hs[3]) require.True(t, has) handlerErr := errors.New("fail") st.getWaitCh(hs[2]) <- handlerErr @@ -252,9 +260,9 @@ func TestSetSyncBase(t *testing.T) { require.ErrorIs(t, err, handlerErr) require.ElementsMatch(t, hs[2:], handledKeys) // only successfully handled key propagate the syncBase - has, err = is.Has(context.Background(), hs[2]) + has, err = os.Has(context.Background(), hs[2]) require.False(t, has) - has, err = is.Has(context.Background(), hs[3]) + has, err = os.Has(context.Background(), hs[3]) require.True(t, has) }) } diff --git a/sync2/hashsync/split_sync.go b/sync2/multipeer/split_sync.go similarity index 90% rename from sync2/hashsync/split_sync.go rename to sync2/multipeer/split_sync.go index d164b72b12..9131bacca4 100644 --- a/sync2/hashsync/split_sync.go +++ b/sync2/multipeer/split_sync.go @@ -1,8 +1,7 @@ -package hashsync +package multipeer import ( "context" - "encoding/binary" "errors" "slices" "time" @@ -11,7 +10,6 @@ import ( "go.uber.org/zap" "golang.org/x/sync/errgroup" - "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/p2p" ) @@ -47,6 +45,7 @@ func newSplitSync( syncPeers []p2p.Peer, gracePeriod time.Duration, clock clockwork.Clock, + keyLen, maxDepth int, ) *splitSync { if len(syncPeers) == 0 { panic("BUG: no peers passed to splitSync") @@ -58,7 +57,7 @@ func newSplitSync( syncPeers: syncPeers, gracePeriod: gracePeriod, clock: clock, - sq: newSyncQueue(len(syncPeers)), + sq: newSyncQueue(len(syncPeers), keyLen, maxDepth), resCh: make(chan syncResult), syncMap: make(map[p2p.Peer]*syncRange), failedPeers: make(map[p2p.Peer]struct{}), @@ -83,7 +82,7 @@ func (s *splitSync) startPeerSync(ctx context.Context, p p2p.Peer, sr 
*syncRange doneCh := make(chan struct{}) s.eg.Go(func() error { defer close(doneCh) - err := syncer.Sync(ctx, &sr.x, &sr.y) + err := syncer.Sync(ctx, sr.x, sr.y) select { case <-ctx.Done(): return ctx.Err() @@ -200,17 +199,3 @@ func (s *splitSync) sync(ctx context.Context) error { s.logger.Debug("QQQQQ: wg wait") return s.eg.Wait() } - -func getDelimiters(numPeers int) (h []types.Hash32) { - if numPeers < 2 { - return nil - } - // QQQQQ: TBD: support maxDepth - inc := (uint64(0x80) << 56) / uint64(numPeers) - h = make([]types.Hash32, numPeers-1) - for i, v := 0, uint64(0); i < numPeers-1; i++ { - v += inc - binary.BigEndian.PutUint64(h[i][:], v<<1) - } - return h -} diff --git a/sync2/hashsync/split_sync_test.go b/sync2/multipeer/split_sync_test.go similarity index 75% rename from sync2/hashsync/split_sync_test.go rename to sync2/multipeer/split_sync_test.go index 2eff5ede4e..95de0ac852 100644 --- a/sync2/hashsync/split_sync_test.go +++ b/sync2/multipeer/split_sync_test.go @@ -1,4 +1,4 @@ -package hashsync +package multipeer import ( "context" @@ -13,62 +13,12 @@ import ( "go.uber.org/zap/zaptest" "golang.org/x/sync/errgroup" - "github.com/spacemeshos/go-spacemesh/common/types" + smtypes "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) -func hexDelimiters(n int) (r []string) { - for _, h := range getDelimiters(n) { - r = append(r, h.String()) - } - return r -} - -func TestGetDelimiters(t *testing.T) { - for _, tc := range []struct { - numPeers int - values []string - }{ - { - numPeers: 0, - values: nil, - }, - { - numPeers: 1, - values: nil, - }, - { - numPeers: 2, - values: []string{ - "8000000000000000000000000000000000000000000000000000000000000000", - }, - }, - { - numPeers: 3, - values: []string{ - "5555555555555554000000000000000000000000000000000000000000000000", - 
"aaaaaaaaaaaaaaa8000000000000000000000000000000000000000000000000", - }, - }, - { - numPeers: 4, - values: []string{ - "4000000000000000000000000000000000000000000000000000000000000000", - "8000000000000000000000000000000000000000000000000000000000000000", - "c000000000000000000000000000000000000000000000000000000000000000", - }, - }, - } { - r := hexDelimiters(tc.numPeers) - if len(tc.values) == 0 { - require.Empty(t, r, "%d delimiters", tc.numPeers) - } else { - require.Equal(t, tc.values, r, "%d delimiters", tc.numPeers) - } - } -} - type splitSyncTester struct { testing.TB @@ -119,7 +69,7 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { peers: peers.New(), } for n := range tst.syncPeers { - tst.syncPeers[n] = p2p.Peer(types.RandomBytes(20)) + tst.syncPeers[n] = p2p.Peer(smtypes.RandomBytes(20)) } for index, p := range tst.syncPeers { index := index @@ -131,7 +81,7 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { s.EXPECT().Peer().Return(p).AnyTimes() s.EXPECT(). Sync(gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, x, y *types.Hash32) error { + DoAndReturn(func(ctx context.Context, x, y types.KeyBytes) error { tst.mtx.Lock() defer tst.mtx.Unlock() require.NotNil(t, ctx) @@ -165,6 +115,7 @@ func newTestSplitSync(t testing.TB) *splitSyncTester { tst.syncPeers, time.Minute, tst.clock, + 32, 24, ) return tst } diff --git a/sync2/hashsync/sync_queue.go b/sync2/multipeer/sync_queue.go similarity index 84% rename from sync2/hashsync/sync_queue.go rename to sync2/multipeer/sync_queue.go index d4805d7b29..b250e944d6 100644 --- a/sync2/hashsync/sync_queue.go +++ b/sync2/multipeer/sync_queue.go @@ -1,14 +1,14 @@ -package hashsync +package multipeer import ( "container/heap" "time" - "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) type syncRange struct { - x, y types.Hash32 + x, y types.KeyBytes lastSyncStarted time.Time done bool numSyncers int @@ -53,16 +53,16 @@ func (sq *syncQueue) Pop() any { return sr } -func newSyncQueue(numPeers int) syncQueue { - delim := getDelimiters(numPeers) - var y types.Hash32 +func newSyncQueue(numPeers, keyLen, maxDepth int) syncQueue { + delim := getDelimiters(numPeers, keyLen, maxDepth) + y := make(types.KeyBytes, keyLen) sq := make(syncQueue, numPeers) for n := range sq { - x := y - if n == numPeers-1 { - y = types.Hash32{} - } else { + x := y.Clone() + if n < numPeers-1 { y = delim[n] + } else { + y = make(types.KeyBytes, keyLen) } sq[n] = &syncRange{ x: x, diff --git a/sync2/hashsync/sync_queue_test.go b/sync2/multipeer/sync_queue_test.go similarity index 97% rename from sync2/hashsync/sync_queue_test.go rename to sync2/multipeer/sync_queue_test.go index 180d2aa0f2..9fa5d0f329 100644 --- a/sync2/hashsync/sync_queue_test.go +++ b/sync2/multipeer/sync_queue_test.go @@ -1,4 +1,4 @@ -package hashsync +package multipeer import ( "testing" @@ -28,7 +28,7 @@ func TestSyncQueue(t *testing.T) { "0000000000000000000000000000000000000000000000000000000000000000", 
}: false, } - sq := newSyncQueue(4) + sq := newSyncQueue(4, 32, 24) startTime := time.Now() pushed := make([]hexRange, 4) for i := 0; i < 4; i++ { diff --git a/sync2/p2p.go b/sync2/p2p.go index fca6a84665..633341eda8 100644 --- a/sync2/p2p.go +++ b/sync2/p2p.go @@ -14,7 +14,8 @@ import ( "github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/p2p/server" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/multipeer" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" ) type Config struct { @@ -33,8 +34,8 @@ type Config struct { func DefaultConfig() Config { return Config{ - MaxSendRange: hashsync.DefaultMaxSendRange, - SampleSize: hashsync.DefaultSampleSize, + MaxSendRange: rangesync.DefaultMaxSendRange, + SampleSize: rangesync.DefaultSampleSize, Timeout: 10 * time.Second, SyncPeerCount: 20, MinSplitSyncPeers: 2, @@ -50,9 +51,9 @@ func DefaultConfig() Config { type P2PHashSync struct { logger *zap.Logger h host.Host - is hashsync.ItemStore - syncBase hashsync.SyncBase - reconciler *hashsync.MultiPeerReconciler + os rangesync.OrderedSet + syncBase multipeer.SyncBase + reconciler *multipeer.MultiPeerReconciler srv *server.Server cancel context.CancelFunc eg errgroup.Group @@ -63,35 +64,37 @@ type P2PHashSync struct { func NewP2PHashSync( logger *zap.Logger, h host.Host, + os rangesync.OrderedSet, + keyLen, maxDepth int, proto string, peers *peers.Peers, - handler hashsync.SyncKeyHandler, + handler multipeer.SyncKeyHandler, cfg Config, ) *P2PHashSync { s := &P2PHashSync{ logger: logger, h: h, - is: hashsync.NewSyncTreeStore(hashsync.Hash32To12Xor{}), + os: os, } s.srv = server.New(h, proto, s.handle, server.WithTimeout(cfg.Timeout), server.WithLog(logger)) - ps := hashsync.NewPairwiseStoreSyncer(s.srv, []hashsync.RangeSetReconcilerOption{ - hashsync.WithMaxSendRange(cfg.MaxSendRange), - hashsync.WithSampleSize(cfg.SampleSize), + ps := rangesync.NewPairwiseSetSyncer(s.srv, 
[]rangesync.RangeSetReconcilerOption{ + rangesync.WithMaxSendRange(cfg.MaxSendRange), + rangesync.WithSampleSize(cfg.SampleSize), }) - s.syncBase = hashsync.NewSetSyncBase(ps, s.is, handler) - s.reconciler = hashsync.NewMultiPeerReconciler( - s.syncBase, peers, - hashsync.WithLogger(logger), - hashsync.WithSyncPeerCount(cfg.SyncPeerCount), - hashsync.WithMinSplitSyncPeers(cfg.MinSplitSyncPeers), - hashsync.WithMinSplitSyncCount(cfg.MinSplitSyncCount), - hashsync.WithMaxFullDiff(cfg.MaxFullDiff), - hashsync.WithSyncInterval(cfg.SyncInterval), - hashsync.WithMinCompleteFraction(cfg.MinCompleteFraction), - hashsync.WithSplitSyncGracePeriod(time.Minute), - hashsync.WithNoPeersRecheckInterval(cfg.NoPeersRecheckInterval)) + s.syncBase = multipeer.NewSetSyncBase(ps, s.os, handler) + s.reconciler = multipeer.NewMultiPeerReconciler( + s.syncBase, peers, keyLen, maxDepth, + multipeer.WithLogger(logger), + multipeer.WithSyncPeerCount(cfg.SyncPeerCount), + multipeer.WithMinSplitSyncPeers(cfg.MinSplitSyncPeers), + multipeer.WithMinSplitSyncCount(cfg.MinSplitSyncCount), + multipeer.WithMaxFullDiff(cfg.MaxFullDiff), + multipeer.WithSyncInterval(cfg.SyncInterval), + multipeer.WithMinCompleteFraction(cfg.MinCompleteFraction), + multipeer.WithSplitSyncGracePeriod(time.Minute), + multipeer.WithNoPeersRecheckInterval(cfg.NoPeersRecheckInterval)) return s } @@ -109,8 +112,8 @@ func (s *P2PHashSync) handle(ctx context.Context, req []byte, stream io.ReadWrit return syncer.Serve(ctx, req, stream) } -func (s *P2PHashSync) ItemStore() hashsync.ItemStore { - return s.is +func (s *P2PHashSync) Set() rangesync.OrderedSet { + return s.os } func (s *P2PHashSync) Start() { diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index e40fa74b94..7c4c4a7cae 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -10,10 +10,10 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" - "github.com/spacemeshos/go-spacemesh/common/types" 
"github.com/spacemeshos/go-spacemesh/fetch/peers" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sync2/hashsync" + "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) func TestP2P(t *testing.T) { @@ -26,14 +26,14 @@ func TestP2P(t *testing.T) { require.NoError(t, err) type addedKey struct { fromPeer, toPeer p2p.Peer - key hashsync.Ordered + key string } var mtx sync.Mutex synced := make(map[addedKey]struct{}) hs := make([]*P2PHashSync, numNodes) - initialSet := make([]types.Hash32, numHashes) + initialSet := make([]types.KeyBytes, numHashes) for n := range initialSet { - initialSet[n] = types.RandomHash() + initialSet[n] = types.RandomKeyBytes(32) } for n := range hs { ps := peers.New() @@ -45,20 +45,21 @@ func TestP2P(t *testing.T) { cfg := DefaultConfig() cfg.SyncInterval = 100 * time.Millisecond host := mesh.Hosts()[n] - handler := func(ctx context.Context, k hashsync.Ordered, peer p2p.Peer) error { + handler := func(ctx context.Context, k types.Ordered, peer p2p.Peer) error { mtx.Lock() defer mtx.Unlock() ak := addedKey{ fromPeer: peer, toPeer: host.ID(), - key: k, + key: string(k.(types.KeyBytes)), } synced[ak] = struct{}{} return nil } - hs[n] = NewP2PHashSync(logger, host, "sync2test", ps, handler, cfg) + os := rangesync.NewDumbHashSet(true) + hs[n] = NewP2PHashSync(logger, host, os, 32, 24, "sync2test", ps, handler, cfg) if n == 0 { - is := hs[n].ItemStore() + is := hs[n].Set() for _, h := range initialSet { is.Add(context.Background(), h) } @@ -69,15 +70,17 @@ func TestP2P(t *testing.T) { require.Eventually(t, func() bool { for _, hsync := range hs { // use a snapshot to avoid races - is := hsync.ItemStore().Copy() - it, err := is.Min(context.Background()) + os := hsync.Set().Copy() + empty, err := os.Empty(context.Background()) require.NoError(t, err) - if it == nil { + if empty { return false } - k, err := it.Key() + seq, err := os.Items(context.Background()) 
require.NoError(t, err) - info, err := is.GetRangeInfo(context.Background(), nil, k, k, -1) + k, err := seq.First() + require.NoError(t, err) + info, err := os.GetRangeInfo(context.Background(), k, k, -1) require.NoError(t, err) if info.Count < numHashes { return false @@ -88,8 +91,7 @@ func TestP2P(t *testing.T) { for _, hsync := range hs { hsync.Stop() - actualItems, err := hashsync.CollectStoreItems[types.Hash32]( - context.Background(), hsync.ItemStore()) + actualItems, err := rangesync.CollectSetItems[types.KeyBytes](context.Background(), hsync.Set()) require.NoError(t, err) require.ElementsMatch(t, initialSet, actualItems) } diff --git a/sync2/rangesync/dumbset.go b/sync2/rangesync/dumbset.go new file mode 100644 index 0000000000..19b150d2f0 --- /dev/null +++ b/sync2/rangesync/dumbset.go @@ -0,0 +1,271 @@ +package rangesync + +import ( + "context" + "crypto/md5" + "errors" + "slices" + "sync" + "time" + + "github.com/zeebo/blake3" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +func stringToFP(s string) types.Fingerprint { + h := md5.New() + h.Write([]byte(s)) + return types.Fingerprint(h.Sum(nil)) +} + +func gtePos(all []types.KeyBytes, item types.KeyBytes) int { + n := slices.IndexFunc(all, func(v types.KeyBytes) bool { + return v.Compare(item) >= 0 + }) + if n >= 0 { + return n + } + return len(all) +} + +func naiveRange(all []types.KeyBytes, x, y types.KeyBytes, stopCount int) (items []types.KeyBytes, startID, endID types.KeyBytes) { + if len(all) == 0 { + return nil, nil, nil + } + // all = slices.Clone(all) // QQQQQ: should not need this + // slices.Sort(all) // QQQQQ: should not need this + start := gtePos(all, x) + end := gtePos(all, y) + if x.Compare(y) < 0 { + if stopCount >= 0 && end-start > stopCount { + end = start + stopCount + } + if end < len(all) { + endID = all[end] + } else { + endID = all[0] + } + startID = nil + if start < len(all) { + startID = all[start] + } else { + startID = all[0] + } + return all[start:end], startID, 
endID + } else { + r := append(all[start:], all[:end]...) + if len(r) == 0 { + return nil, all[0], all[0] + } + if stopCount >= 0 && len(r) > stopCount { + return r[:stopCount], r[0], r[stopCount] + } + if end < len(all) { + endID = all[end] + } else { + endID = all[0] + } + startID = nil + if len(r) != 0 { + startID = r[0] + } + return r, startID, endID + } +} + +var naiveFPFunc = func(items []types.KeyBytes) types.Fingerprint { + s := "" + for _, k := range items { + s += string(k) + } + return stringToFP(s) +} + +type dumbSet struct { + keys []types.KeyBytes + disableReAdd bool + added map[string]bool + fpFunc func(items []types.KeyBytes) types.Fingerprint +} + +var _ OrderedSet = &dumbSet{} + +func (ds *dumbSet) Add(ctx context.Context, k types.Ordered) error { + id := k.(types.KeyBytes) + if len(ds.keys) == 0 { + ds.keys = []types.KeyBytes{id} + return nil + } + p := slices.IndexFunc(ds.keys, func(other types.KeyBytes) bool { + return other.Compare(id) >= 0 + }) + switch { + case p < 0: + ds.keys = append(ds.keys, id) + case id.Compare(ds.keys[p]) == 0: + if ds.disableReAdd { + if ds.added[string(id)] { + panic("hash sent twice: " + id.String()) + } + if ds.added == nil { + ds.added = make(map[string]bool) + } + ds.added[string(id)] = true + } + // already present + default: + ds.keys = slices.Insert(ds.keys, p, id) + } + + return nil +} + +func (ds *dumbSet) seq(n int) types.Seq { + if n < -0 || n > len(ds.keys) { + panic("bad index") + } + return types.Seq(func(yield func(types.Ordered, error) bool) { + n := n // make the sequence reusable + for { + if !yield(ds.keys[n], nil) { + break + } + n = (n + 1) % len(ds.keys) + } + }) +} + +func (ds *dumbSet) seqFor(s types.KeyBytes) types.Seq { + n := slices.IndexFunc(ds.keys, func(k types.KeyBytes) bool { + return k.Compare(s) == 0 + }) + if n == -1 { + panic("item not found: " + s.String()) + } + return ds.seq(n) +} + +func (ds *dumbSet) getRangeInfo( + _ context.Context, + x, y types.Ordered, + count int, +) (r 
RangeInfo, end types.KeyBytes, err error) { + if x == nil && y == nil { + if len(ds.keys) == 0 { + return RangeInfo{ + Fingerprint: types.EmptyFingerprint(), + }, nil, nil + } + x = ds.keys[0] + y = x + } else if x == nil || y == nil { + panic("BUG: bad X or Y") + } + vx := x.(types.KeyBytes) + vy := y.(types.KeyBytes) + rangeItems, start, end := naiveRange(ds.keys, vx, vy, count) + fpFunc := ds.fpFunc + if fpFunc == nil { + fpFunc = naiveFPFunc + } + r = RangeInfo{ + Fingerprint: fpFunc(rangeItems), + Count: len(rangeItems), + } + if r.Count != 0 { + if start == nil || end == nil { + panic("empty start/end from naiveRange") + } + r.Items = ds.seqFor(start) + } + return r, end, nil +} + +func (ds *dumbSet) GetRangeInfo( + ctx context.Context, + x, y types.Ordered, + count int, +) (RangeInfo, error) { + ri, _, err := ds.getRangeInfo(ctx, x, y, count) + return ri, err +} + +func (ds *dumbSet) SplitRange( + ctx context.Context, + x, y types.Ordered, + count int, +) (SplitInfo, error) { + if count <= 0 { + panic("BUG: bad split count") + } + part0, middle, err := ds.getRangeInfo(ctx, x, y, count) + if err != nil { + return SplitInfo{}, err + } + if part0.Count == 0 { + return SplitInfo{}, errors.New("can't split empty range") + } + part1, err := ds.GetRangeInfo(ctx, middle, y, -1) + if err != nil { + return SplitInfo{}, err + } + return SplitInfo{ + Parts: [2]RangeInfo{part0, part1}, + Middle: middle, + }, nil +} + +func (ds *dumbSet) Empty(ctx context.Context) (bool, error) { + return len(ds.keys) == 0, nil +} + +func (ds *dumbSet) Items(ctx context.Context) (types.Seq, error) { + if len(ds.keys) == 0 { + return types.EmptySeq(), nil + } + return ds.seq(0), nil +} + +func (ds *dumbSet) Copy() OrderedSet { + return &dumbSet{keys: slices.Clone(ds.keys)} +} + +func (ds *dumbSet) Has(ctx context.Context, k types.Ordered) (bool, error) { + for _, cur := range ds.keys { + if k.Compare(cur) == 0 { + return true, nil + } + } + return false, nil +} + +func (ds *dumbSet) 
Recent(ctx context.Context, since time.Time) (types.Seq, int, error) { + return nil, 0, nil +} + +var hashPool = &sync.Pool{ + New: func() any { + return blake3.New() + }, +} + +func NewDumbHashSet(disableReAdd bool) OrderedSet { + return &dumbSet{ + disableReAdd: disableReAdd, + fpFunc: func(items []types.KeyBytes) (r types.Fingerprint) { + hasher := hashPool.Get().(*blake3.Hasher) + defer func() { + hasher.Reset() + hashPool.Put(hasher) + }() + var hashRes [32]byte + for _, h := range items { + hasher.Write(h[:]) + } + hasher.Sum(hashRes[:0]) + copy(r[:], hashRes[:]) + return r + }, + } +} diff --git a/sync2/rangesync/interface.go b/sync2/rangesync/interface.go new file mode 100644 index 0000000000..8ada2d35d2 --- /dev/null +++ b/sync2/rangesync/interface.go @@ -0,0 +1,62 @@ +package rangesync + +import ( + "context" + "time" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./interface.go + +// RangeInfo contains information about a range of items in the OrderedSet as returned by +// OrderedSet.GetRangeInfo. +type RangeInfo struct { + // Fingerprint of the interval + Fingerprint types.Fingerprint + // Number of items in the interval + Count int + // Items is the sequence of set elements in the interval. + Items types.Seq +} + +// SplitInfo contains information about range split in two. +type SplitInfo struct { + // 2 parts of the range + Parts [2]RangeInfo + // Middle point between the ranges + Middle types.Ordered +} + +// OrderedSet represents the set that can be synced against a remote peer +type OrderedSet interface { + // Add adds a key to the set + Add(ctx context.Context, k types.Ordered) error + // GetRangeInfo returns RangeInfo for the item range in the set. 
+ // If count >= 0, at most count items are returned, and RangeInfo + // is returned for the corresponding subrange of the requested range. + // X and Y must not be nil. + GetRangeInfo(ctx context.Context, x, y types.Ordered, count int) (RangeInfo, error) + // SplitRange splits the range roughly after the specified count of items, + // returning RangeInfo for the first half and the second half of the range. + SplitRange(ctx context.Context, x, y types.Ordered, count int) (SplitInfo, error) + // Items returns the sequence of items in the set. + Items(ctx context.Context) (types.Seq, error) + // Empty returns true if the set is empty. + Empty(ctx context.Context) (bool, error) + // Copy makes a shallow copy of the OrderedSet + Copy() OrderedSet + // Has returns true if the specified key is present in OrderedSet + Has(ctx context.Context, k types.Ordered) (bool, error) + // Recent returns a sequence that yields the items added since the specified + // timestamp. Some OrderedSet implementations may not have Recent implemented, in + // which case it should return an error. 
+ Recent(ctx context.Context, since time.Time) (types.Seq, int, error) +} + +type Requester interface { + Run(context.Context) error + StreamRequest(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error +} diff --git a/sync2/hashsync/log.go b/sync2/rangesync/log.go similarity index 61% rename from sync2/hashsync/log.go rename to sync2/rangesync/log.go index 980c9f1434..25a979ac1a 100644 --- a/sync2/hashsync/log.go +++ b/sync2/rangesync/log.go @@ -1,4 +1,4 @@ -package hashsync +package rangesync import ( "encoding/hex" @@ -7,26 +7,25 @@ import ( "go.uber.org/zap" - "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sync2/types" ) -type itFormatter struct { - it Iterator +type seqFormatter struct { + seq types.Seq } -func (f itFormatter) String() string { - k, err := f.it.Key() - if err != nil { - return fmt.Sprintf("", err) +func (f seqFormatter) String() string { + for k, err := range f.seq { + if err != nil { + return fmt.Sprintf("", err) + } + return fmt.Sprintf("[%s, ...]", hexStr(k)) } - return hexStr(k) + return "" } -func IteratorField(name string, it Iterator) zap.Field { - if it == nil { - return zap.String(name, "") - } - return zap.Stringer(name, itFormatter{it: it}) +func SeqField(name string, seq types.Seq) zap.Field { + return zap.Stringer(name, seqFormatter{seq: seq}) } // based on code from testify @@ -50,10 +49,8 @@ func isNil(object any) bool { func hexStr(k any) string { switch h := k.(type) { - case types.Hash32: - return h.ShortString() - case types.Hash12: - return hex.EncodeToString(h[:5]) + case types.KeyBytes: + return h.String() case []byte: if len(h) > 5 { h = h[:5] diff --git a/sync2/rangesync/mocks/mocks.go b/sync2/rangesync/mocks/mocks.go new file mode 100644 index 0000000000..a366ff61e5 --- /dev/null +++ b/sync2/rangesync/mocks/mocks.go @@ -0,0 +1,460 @@ +// Code generated by MockGen. DO NOT EDIT. 
+// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./interface.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + time "time" + + p2p "github.com/spacemeshos/go-spacemesh/p2p" + server "github.com/spacemeshos/go-spacemesh/p2p/server" + rangesync "github.com/spacemeshos/go-spacemesh/sync2/rangesync" + types "github.com/spacemeshos/go-spacemesh/sync2/types" + gomock "go.uber.org/mock/gomock" +) + +// MockOrderedSet is a mock of OrderedSet interface. +type MockOrderedSet struct { + ctrl *gomock.Controller + recorder *MockOrderedSetMockRecorder +} + +// MockOrderedSetMockRecorder is the mock recorder for MockOrderedSet. +type MockOrderedSetMockRecorder struct { + mock *MockOrderedSet +} + +// NewMockOrderedSet creates a new mock instance. +func NewMockOrderedSet(ctrl *gomock.Controller) *MockOrderedSet { + mock := &MockOrderedSet{ctrl: ctrl} + mock.recorder = &MockOrderedSetMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOrderedSet) EXPECT() *MockOrderedSetMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockOrderedSet) Add(ctx context.Context, k types.Ordered) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Add", ctx, k) + ret0, _ := ret[0].(error) + return ret0 +} + +// Add indicates an expected call of Add. 
+func (mr *MockOrderedSetMockRecorder) Add(ctx, k any) *MockOrderedSetAddCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockOrderedSet)(nil).Add), ctx, k) + return &MockOrderedSetAddCall{Call: call} +} + +// MockOrderedSetAddCall wrap *gomock.Call +type MockOrderedSetAddCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetAddCall) Return(arg0 error) *MockOrderedSetAddCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetAddCall) Do(f func(context.Context, types.Ordered) error) *MockOrderedSetAddCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetAddCall) DoAndReturn(f func(context.Context, types.Ordered) error) *MockOrderedSetAddCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Copy mocks base method. +func (m *MockOrderedSet) Copy() rangesync.OrderedSet { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Copy") + ret0, _ := ret[0].(rangesync.OrderedSet) + return ret0 +} + +// Copy indicates an expected call of Copy. 
+func (mr *MockOrderedSetMockRecorder) Copy() *MockOrderedSetCopyCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", reflect.TypeOf((*MockOrderedSet)(nil).Copy)) + return &MockOrderedSetCopyCall{Call: call} +} + +// MockOrderedSetCopyCall wrap *gomock.Call +type MockOrderedSetCopyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetCopyCall) Return(arg0 rangesync.OrderedSet) *MockOrderedSetCopyCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetCopyCall) Do(f func() rangesync.OrderedSet) *MockOrderedSetCopyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetCopyCall) DoAndReturn(f func() rangesync.OrderedSet) *MockOrderedSetCopyCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Empty mocks base method. +func (m *MockOrderedSet) Empty(ctx context.Context) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Empty", ctx) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Empty indicates an expected call of Empty. 
+func (mr *MockOrderedSetMockRecorder) Empty(ctx any) *MockOrderedSetEmptyCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Empty", reflect.TypeOf((*MockOrderedSet)(nil).Empty), ctx) + return &MockOrderedSetEmptyCall{Call: call} +} + +// MockOrderedSetEmptyCall wrap *gomock.Call +type MockOrderedSetEmptyCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetEmptyCall) Return(arg0 bool, arg1 error) *MockOrderedSetEmptyCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetEmptyCall) Do(f func(context.Context) (bool, error)) *MockOrderedSetEmptyCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetEmptyCall) DoAndReturn(f func(context.Context) (bool, error)) *MockOrderedSetEmptyCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// GetRangeInfo mocks base method. +func (m *MockOrderedSet) GetRangeInfo(ctx context.Context, x, y types.Ordered, count int) (rangesync.RangeInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRangeInfo", ctx, x, y, count) + ret0, _ := ret[0].(rangesync.RangeInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRangeInfo indicates an expected call of GetRangeInfo. 
+func (mr *MockOrderedSetMockRecorder) GetRangeInfo(ctx, x, y, count any) *MockOrderedSetGetRangeInfoCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRangeInfo", reflect.TypeOf((*MockOrderedSet)(nil).GetRangeInfo), ctx, x, y, count) + return &MockOrderedSetGetRangeInfoCall{Call: call} +} + +// MockOrderedSetGetRangeInfoCall wrap *gomock.Call +type MockOrderedSetGetRangeInfoCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetGetRangeInfoCall) Return(arg0 rangesync.RangeInfo, arg1 error) *MockOrderedSetGetRangeInfoCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetGetRangeInfoCall) Do(f func(context.Context, types.Ordered, types.Ordered, int) (rangesync.RangeInfo, error)) *MockOrderedSetGetRangeInfoCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetGetRangeInfoCall) DoAndReturn(f func(context.Context, types.Ordered, types.Ordered, int) (rangesync.RangeInfo, error)) *MockOrderedSetGetRangeInfoCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Has mocks base method. +func (m *MockOrderedSet) Has(ctx context.Context, k types.Ordered) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Has", ctx, k) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Has indicates an expected call of Has. 
+func (mr *MockOrderedSetMockRecorder) Has(ctx, k any) *MockOrderedSetHasCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Has", reflect.TypeOf((*MockOrderedSet)(nil).Has), ctx, k) + return &MockOrderedSetHasCall{Call: call} +} + +// MockOrderedSetHasCall wrap *gomock.Call +type MockOrderedSetHasCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetHasCall) Return(arg0 bool, arg1 error) *MockOrderedSetHasCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetHasCall) Do(f func(context.Context, types.Ordered) (bool, error)) *MockOrderedSetHasCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetHasCall) DoAndReturn(f func(context.Context, types.Ordered) (bool, error)) *MockOrderedSetHasCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Items mocks base method. +func (m *MockOrderedSet) Items(ctx context.Context) (types.Seq, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Items", ctx) + ret0, _ := ret[0].(types.Seq) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Items indicates an expected call of Items. 
+func (mr *MockOrderedSetMockRecorder) Items(ctx any) *MockOrderedSetItemsCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Items", reflect.TypeOf((*MockOrderedSet)(nil).Items), ctx) + return &MockOrderedSetItemsCall{Call: call} +} + +// MockOrderedSetItemsCall wrap *gomock.Call +type MockOrderedSetItemsCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetItemsCall) Return(arg0 types.Seq, arg1 error) *MockOrderedSetItemsCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetItemsCall) Do(f func(context.Context) (types.Seq, error)) *MockOrderedSetItemsCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetItemsCall) DoAndReturn(f func(context.Context) (types.Seq, error)) *MockOrderedSetItemsCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Recent mocks base method. +func (m *MockOrderedSet) Recent(ctx context.Context, since time.Time) (types.Seq, int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recent", ctx, since) + ret0, _ := ret[0].(types.Seq) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// Recent indicates an expected call of Recent. 
+func (mr *MockOrderedSetMockRecorder) Recent(ctx, since any) *MockOrderedSetRecentCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recent", reflect.TypeOf((*MockOrderedSet)(nil).Recent), ctx, since) + return &MockOrderedSetRecentCall{Call: call} +} + +// MockOrderedSetRecentCall wrap *gomock.Call +type MockOrderedSetRecentCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetRecentCall) Return(arg0 types.Seq, arg1 int, arg2 error) *MockOrderedSetRecentCall { + c.Call = c.Call.Return(arg0, arg1, arg2) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetRecentCall) Do(f func(context.Context, time.Time) (types.Seq, int, error)) *MockOrderedSetRecentCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetRecentCall) DoAndReturn(f func(context.Context, time.Time) (types.Seq, int, error)) *MockOrderedSetRecentCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// SplitRange mocks base method. +func (m *MockOrderedSet) SplitRange(ctx context.Context, x, y types.Ordered, count int) (rangesync.SplitInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SplitRange", ctx, x, y, count) + ret0, _ := ret[0].(rangesync.SplitInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SplitRange indicates an expected call of SplitRange. 
+func (mr *MockOrderedSetMockRecorder) SplitRange(ctx, x, y, count any) *MockOrderedSetSplitRangeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SplitRange", reflect.TypeOf((*MockOrderedSet)(nil).SplitRange), ctx, x, y, count) + return &MockOrderedSetSplitRangeCall{Call: call} +} + +// MockOrderedSetSplitRangeCall wrap *gomock.Call +type MockOrderedSetSplitRangeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockOrderedSetSplitRangeCall) Return(arg0 rangesync.SplitInfo, arg1 error) *MockOrderedSetSplitRangeCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockOrderedSetSplitRangeCall) Do(f func(context.Context, types.Ordered, types.Ordered, int) (rangesync.SplitInfo, error)) *MockOrderedSetSplitRangeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockOrderedSetSplitRangeCall) DoAndReturn(f func(context.Context, types.Ordered, types.Ordered, int) (rangesync.SplitInfo, error)) *MockOrderedSetSplitRangeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockRequester is a mock of Requester interface. +type MockRequester struct { + ctrl *gomock.Controller + recorder *MockRequesterMockRecorder +} + +// MockRequesterMockRecorder is the mock recorder for MockRequester. +type MockRequesterMockRecorder struct { + mock *MockRequester +} + +// NewMockRequester creates a new mock instance. +func NewMockRequester(ctrl *gomock.Controller) *MockRequester { + mock := &MockRequester{ctrl: ctrl} + mock.recorder = &MockRequesterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequester) EXPECT() *MockRequesterMockRecorder { + return m.recorder +} + +// Run mocks base method. 
+func (m *MockRequester) Run(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Run indicates an expected call of Run. +func (mr *MockRequesterMockRecorder) Run(arg0 any) *MockRequesterRunCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRequester)(nil).Run), arg0) + return &MockRequesterRunCall{Call: call} +} + +// MockRequesterRunCall wrap *gomock.Call +type MockRequesterRunCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRequesterRunCall) Return(arg0 error) *MockRequesterRunCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRequesterRunCall) Do(f func(context.Context) error) *MockRequesterRunCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRequesterRunCall) DoAndReturn(f func(context.Context) error) *MockRequesterRunCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// StreamRequest mocks base method. +func (m *MockRequester) StreamRequest(arg0 context.Context, arg1 p2p.Peer, arg2 []byte, arg3 server.StreamRequestCallback, arg4 ...string) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1, arg2, arg3} + for _, a := range arg4 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "StreamRequest", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// StreamRequest indicates an expected call of StreamRequest. +func (mr *MockRequesterMockRecorder) StreamRequest(arg0, arg1, arg2, arg3 any, arg4 ...any) *MockRequesterStreamRequestCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StreamRequest", reflect.TypeOf((*MockRequester)(nil).StreamRequest), varargs...) 
+ return &MockRequesterStreamRequestCall{Call: call} +} + +// MockRequesterStreamRequestCall wrap *gomock.Call +type MockRequesterStreamRequestCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRequesterStreamRequestCall) Return(arg0 error) *MockRequesterStreamRequestCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRequesterStreamRequestCall) Do(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockRequesterStreamRequestCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRequesterStreamRequestCall) DoAndReturn(f func(context.Context, p2p.Peer, []byte, server.StreamRequestCallback, ...string) error) *MockRequesterStreamRequestCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/sync2/rangesync/p2p.go b/sync2/rangesync/p2p.go new file mode 100644 index 0000000000..1102a54811 --- /dev/null +++ b/sync2/rangesync/p2p.go @@ -0,0 +1,107 @@ +package rangesync + +import ( + "bytes" + "context" + "io" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type PairwiseSetSyncer struct { + r Requester + opts []RangeSetReconcilerOption +} + +func NewPairwiseSetSyncer(r Requester, opts []RangeSetReconcilerOption) *PairwiseSetSyncer { + return &PairwiseSetSyncer{r: r, opts: opts} +} + +func (pss *PairwiseSetSyncer) Probe( + ctx context.Context, + peer p2p.Peer, + os OrderedSet, + x, y types.KeyBytes, +) (ProbeResult, error) { + var ( + err error + initReq []byte + info RangeInfo + pr ProbeResult + ) + var c wireConduit + rsr := NewRangeSetReconciler(os, pss.opts...) 
+ if x == nil { + initReq, err = c.withInitialRequest(func(c Conduit) error { + info, err = rsr.InitiateProbe(ctx, c) + return err + }) + } else { + initReq, err = c.withInitialRequest(func(c Conduit) error { + info, err = rsr.InitiateBoundedProbe(ctx, c, x, y) + return err + }) + } + if err != nil { + return ProbeResult{}, err + } + err = pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + c.stream = stream + var err error + pr, err = rsr.HandleProbeResponse(&c, info) + return err + }) + if err != nil { + return ProbeResult{}, err + } + return pr, nil +} + +func (pss *PairwiseSetSyncer) Sync( + ctx context.Context, + peer p2p.Peer, + os OrderedSet, + x, y types.KeyBytes, +) error { + var c wireConduit + rsr := NewRangeSetReconciler(os, pss.opts...) + var ( + initReq []byte + err error + ) + if x == nil { + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.Initiate(ctx, c) + }) + } else { + initReq, err = c.withInitialRequest(func(c Conduit) error { + return rsr.InitiateBounded(ctx, c, x, y) + }) + } + if err != nil { + return err + } + return pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { + return c.handleStream(ctx, stream, rsr) + }) +} + +func (pss *PairwiseSetSyncer) Serve( + ctx context.Context, + req []byte, + stream io.ReadWriter, + os OrderedSet, +) error { + var c wireConduit + rsr := NewRangeSetReconciler(os, pss.opts...) 
+ s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(req), stream), + Writer: stream, + } + return c.handleStream(ctx, s, rsr) +} diff --git a/sync2/rangesync/p2p_test.go b/sync2/rangesync/p2p_test.go new file mode 100644 index 0000000000..dc925955ff --- /dev/null +++ b/sync2/rangesync/p2p_test.go @@ -0,0 +1,358 @@ +package rangesync + +import ( + "context" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/jonboulle/clockwork" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type getRequesterFunc func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) + +type clientServerTester struct { + client Requester + srvPeerID p2p.Peer + bytesReadValue atomic.Int64 + bytesWrittenValue atomic.Int64 +} + +func newClientServerTester( + t *testing.T, + set OrderedSet, + getRequester getRequesterFunc, + opts []RangeSetReconcilerOption, +) (*clientServerTester, context.Context) { + var ( + cst clientServerTester + srv Requester + ) + srvHandler := func(ctx context.Context, req []byte, stream io.ReadWriter) error { + pss := NewPairwiseSetSyncer(nil, opts) + return pss.Serve(ctx, req, wrapStream(&cst, stream), set) + } + srv, cst.srvPeerID = getRequester("srv", srvHandler) + var eg errgroup.Group + ctx, cancel := context.WithCancel(context.Background()) + eg.Go(func() error { + return srv.Run(ctx) + }) + t.Cleanup(func() { + cancel() + eg.Wait() + }) + + cst.client, _ = getRequester("client", nil, srv) + return &cst, ctx +} + +func (cst *clientServerTester) bytesRead() int64 { + return cst.bytesReadValue.Load() +} + +func (cst *clientServerTester) bytesWritten() int64 { + return 
cst.bytesWrittenValue.Load() +} + +type countingStream struct { + io.ReadWriter + cst *clientServerTester +} + +func (s *countingStream) Read(p []byte) (n int, err error) { + n, err = s.ReadWriter.Read(p) + s.cst.bytesReadValue.Add(int64(n)) + return n, err +} + +func (s *countingStream) Write(p []byte) (n int, err error) { + n, err = s.ReadWriter.Write(p) + s.cst.bytesWrittenValue.Add(int64(n)) + return n, err +} + +func wrapStream(cst *clientServerTester, s io.ReadWriter) io.ReadWriter { + return &countingStream{cst: cst, ReadWriter: s} +} + +func fakeRequesterGetter() getRequesterFunc { + return func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) { + pid := p2p.Peer(name) + return newFakeRequester(pid, handler, peers...), pid + } +} + +func p2pRequesterGetter(t *testing.T) getRequesterFunc { + mesh, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + proto := "itest" + opts := []server.Opt{ + server.WithRequestSizeLimit(100_000_000), + server.WithTimeout(10 * time.Second), + server.WithLog(zaptest.NewLogger(t)), + } + return func(name string, handler server.StreamHandler, peers ...Requester) (Requester, p2p.Peer) { + if len(peers) == 0 { + return server.New(mesh.Hosts()[0], proto, handler, opts...), mesh.Hosts()[0].ID() + } + s := server.New(mesh.Hosts()[1], proto, handler, opts...) 
+ // TODO: this 'Eventually' is somewhat misplaced + require.Eventually(t, func() bool { + for _, h := range mesh.Hosts()[0:] { + if len(h.Mux().Protocols()) == 0 { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) + return s, mesh.Hosts()[1].ID() + } +} + +type syncTracer struct { + dumb bool + receivedItems int + sentItems int +} + +var _ Tracer = &syncTracer{} + +func (tr *syncTracer) OnDumbSync() { + tr.dumb = true +} + +func (tr *syncTracer) OnRecent(receivedItems, sentItems int) { + tr.receivedItems += receivedItems + tr.sentItems += sentItems +} + +type fakeRecentSet struct { + OrderedSet + timestamps map[string]time.Time + clock clockwork.Clock +} + +var _ OrderedSet = &fakeRecentSet{} + +var startDate = time.Date(2024, 8, 29, 18, 0, 0, 0, time.UTC) + +func (frs *fakeRecentSet) registerAll(ctx context.Context) error { + frs.timestamps = make(map[string]time.Time) + t := startDate + items, err := CollectSetItems[types.KeyBytes](ctx, frs.OrderedSet) + if err != nil { + return err + } + for _, v := range items { + frs.timestamps[string(v)] = t + t = t.Add(time.Second) + } + return nil +} + +func (frs *fakeRecentSet) Add(ctx context.Context, k types.Ordered) error { + if err := frs.OrderedSet.Add(ctx, k); err != nil { + return err + } + h := k.(types.KeyBytes) + frs.timestamps[string(h)] = frs.clock.Now() + return nil +} + +func (frs *fakeRecentSet) Recent(ctx context.Context, since time.Time) (types.Seq, int, error) { + var items []types.KeyBytes + items, err := CollectSetItems[types.KeyBytes](ctx, frs.OrderedSet) + if err != nil { + return nil, 0, err + } + for _, k := range items { + if !frs.timestamps[string(k)].Before(since) { + items = append(items, k) + } + } + return func(yield func(types.Ordered, error) bool) { + for _, h := range items { + if !yield(h, nil) { + return + } + } + }, len(items), nil +} + +func testWireSync(t *testing.T, getRequester getRequesterFunc) { + for _, tc := range []struct { + name string + cfg 
hashSyncTestConfig + dumb bool + opts []RangeSetReconcilerOption + advance time.Duration + sentRecent bool + receivedRecent bool + }{ + { + name: "non-dumb sync", + cfg: hashSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 1000, + minNumSpecificA: 8, + maxNumSpecificA: 16, + minNumSpecificB: 8, + maxNumSpecificB: 16, + }, + dumb: false, + }, + { + name: "dumb sync", + cfg: hashSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 1000, + minNumSpecificA: 400, + maxNumSpecificA: 500, + minNumSpecificB: 400, + maxNumSpecificB: 500, + }, + dumb: true, + }, + { + name: "recent sync", + cfg: hashSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 1000, + minNumSpecificA: 400, + maxNumSpecificA: 500, + minNumSpecificB: 400, + maxNumSpecificB: 500, + allowReAdd: true, + }, + dumb: false, + opts: []RangeSetReconcilerOption{ + WithRecentTimeSpan(990 * time.Second), + }, + advance: 1000 * time.Second, + sentRecent: true, + receivedRecent: true, + }, + { + name: "larger sync", + cfg: hashSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 10000, + minNumSpecificA: 4, + maxNumSpecificA: 100, + minNumSpecificB: 4, + maxNumSpecificB: 100, + }, + dumb: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + st := newHashSyncTester(t, tc.cfg) + clock := clockwork.NewFakeClockAt(startDate) + // Note that at this point, the items are already added to the sets + // and thus fakeRecentSet.Add is not invoked for them, just underlying + // set's Add method + setA := &fakeRecentSet{OrderedSet: st.setA, clock: clock} + require.NoError(t, setA.registerAll(context.Background())) + setB := &fakeRecentSet{OrderedSet: st.setB, clock: clock} + require.NoError(t, setB.registerAll(context.Background())) + var tr syncTracer + opts := append(st.opts, WithTracer(&tr), WithClock(clock)) + opts = append(opts, tc.opts...) 
+ opts = opts[0:len(opts):len(opts)] + clock.Advance(tc.advance) + cst, ctx := newClientServerTester(t, setA, getRequester, opts) + // nr := RmmeNumRead() + // nw := RmmeNumWritten() + pss := NewPairwiseSetSyncer(cst.client, opts) + err := pss.Sync(ctx, cst.srvPeerID, setB, nil, nil) + require.NoError(t, err) + + t.Logf("numSpecific: %d, bytesSent %d, bytesReceived %d", + st.numSpecificA+st.numSpecificB, + cst.bytesRead(), cst.bytesWritten()) + require.Equal(t, tc.dumb, tr.dumb, "dumb sync") + require.Equal(t, tc.receivedRecent, tr.receivedItems > 0) + require.Equal(t, tc.sentRecent, tr.sentItems > 0) + st.verify() + }) + } +} + +func TestWireSync(t *testing.T) { + t.Run("fake requester", func(t *testing.T) { + testWireSync(t, fakeRequesterGetter()) + }) + t.Run("p2p", func(t *testing.T) { + testWireSync(t, p2pRequesterGetter(t)) + }) +} + +func testWireProbe(t *testing.T, getRequester getRequesterFunc) { + st := newHashSyncTester(t, hashSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 10000, + minNumSpecificA: 130, + maxNumSpecificA: 130, + minNumSpecificB: 130, + maxNumSpecificB: 130, + }) + cst, ctx := newClientServerTester(t, st.setA, getRequester, st.opts) + pss := NewPairwiseSetSyncer(cst.client, st.opts) + itemsA, err := st.setA.Items(ctx) + require.NoError(t, err) + kA, err := itemsA.First() + require.NoError(t, err) + infoA, err := st.setA.GetRangeInfo(ctx, kA, kA, -1) + require.NoError(t, err) + prA, err := pss.Probe(ctx, cst.srvPeerID, st.setB, nil, nil) + require.NoError(t, err) + require.Equal(t, infoA.Fingerprint, prA.FP) + require.Equal(t, infoA.Count, prA.Count) + require.InDelta(t, 0.98, prA.Sim, 0.05, "sim") + + itemsA, err = st.setA.Items(ctx) + require.NoError(t, err) + kA, err = itemsA.First() + require.NoError(t, err) + partInfoA, err := st.setA.GetRangeInfo(ctx, kA, kA, infoA.Count/2) + require.NoError(t, err) + xK, err := partInfoA.Items.First() + require.NoError(t, err) + x := xK.(types.KeyBytes) + var y types.KeyBytes + n := 
partInfoA.Count + 1 + for k, err := range partInfoA.Items { + if err != nil { + break + } + y = k.(types.KeyBytes) + n-- + if n == 0 { + break + } + } + prA, err = pss.Probe(ctx, cst.srvPeerID, st.setB, x, y) + require.NoError(t, err) + require.Equal(t, partInfoA.Fingerprint, prA.FP) + require.Equal(t, partInfoA.Count, prA.Count) + require.InDelta(t, 0.98, prA.Sim, 0.1, "sim") +} + +func TestWireProbe(t *testing.T) { + t.Run("fake requester", func(t *testing.T) { + testWireProbe(t, fakeRequesterGetter()) + }) + t.Run("p2p", func(t *testing.T) { + testWireProbe(t, p2pRequesterGetter(t)) + }) +} diff --git a/sync2/hashsync/rangesync.go b/sync2/rangesync/rangesync.go similarity index 77% rename from sync2/hashsync/rangesync.go rename to sync2/rangesync/rangesync.go index 070984d061..8e23361869 100644 --- a/sync2/hashsync/rangesync.go +++ b/sync2/rangesync/rangesync.go @@ -1,10 +1,9 @@ -package hashsync +package rangesync import ( "context" "errors" "fmt" - "iter" "reflect" "slices" "strings" @@ -12,6 +11,8 @@ import ( "github.com/jonboulle/clockwork" "go.uber.org/zap" + + "github.com/spacemeshos/go-spacemesh/sync2/types" ) // Interactions: @@ -242,11 +243,11 @@ func (mtype MessageType) String() string { type SyncMessage interface { Type() MessageType - X() Ordered - Y() Ordered - Fingerprint() any + X() types.Ordered + Y() types.Ordered + Fingerprint() types.Fingerprint Count() int - Keys() []Ordered + Keys() []types.Ordered Since() time.Time } @@ -277,7 +278,7 @@ func SyncMessageToString(m SyncMessage) string { if count := m.Count(); count != 0 { fmt.Fprintf(&sb, " Count=%d", count) } - if fp := m.Fingerprint(); fp != nil { + if fp := m.Fingerprint(); fp != types.EmptyFingerprint() { sb.WriteString(" FP=" + formatID(fp)) } for _, k := range m.Keys() { @@ -287,47 +288,45 @@ func SyncMessageToString(m SyncMessage) string { return sb.String() } -// Conduit handles receiving and sending peer messages -// TODO: replace multiple Send* methods with a single one -// (after 
de-generalizing messages) +// Conduit handles receiving and sending peer messages. type Conduit interface { // NextMessage returns the next SyncMessage, or nil if there are no more // SyncMessages for this session. NextMessage is only called after a NextItem call // indicates that there are no more items. NextMessage should not be called after - // any of Send...() methods is invoked + // any of Send...() methods is invoked. NextMessage() (SyncMessage, error) // SendFingerprint sends range fingerprint to the peer. // Count must be > 0 - SendFingerprint(x, y Ordered, fingerprint any, count int) error + SendFingerprint(x, y types.Ordered, fingerprint types.Fingerprint, count int) error // SendEmptySet notifies the peer that it we don't have any items. - // The corresponding SyncMessage has Count() == 0, X() == nil and Y() == nil + // The corresponding SyncMessage has Count() == 0, X() == nil and Y() == nil. SendEmptySet() error // SendEmptyRange notifies the peer that the specified range - // is empty on our side. The corresponding SyncMessage has Count() == 0 - SendEmptyRange(x, y Ordered) error + // is empty on our side. The corresponding SyncMessage has Count() == 0. + SendEmptyRange(x, y types.Ordered) error // SendRangeContents notifies the peer that the corresponding range items will // be included in this sync round. The items themselves are sent via - // SendItems - SendRangeContents(x, y Ordered, count int) error - // SendItems sends a chunk of items - SendChunk(items []Ordered) error - // SendEndRound sends a message that signifies the end of sync round + // SendItems. + SendRangeContents(x, y types.Ordered, count int) error + // SendItems sends a chunk of items. + SendChunk(items []types.Ordered) error + // SendEndRound sends a message that signifies the end of sync round. SendEndRound() error - // SendDone sends a message that notifies the peer that sync is finished + // SendDone sends a message that notifies the peer that sync is finished. 
SendDone() error // SendProbe sends a message requesting fingerprint and count of the // whole range or part of the range. If fingerprint is provided and // it doesn't match the fingerprint on the probe handler side, // the handler must send a sample subset of its items for MinHash // calculation. - SendProbe(x, y Ordered, fingerprint any, sampleSize int) error + SendProbe(x, y types.Ordered, fingerprint types.Fingerprint, sampleSize int) error // SendSample sends a set sample. If 'it' is not nil, the corresponding items are - // included in the sample - SendSample(x, y Ordered, fingerprint any, count, sampleSize int, it Iterator) error - // SendRecent sends recent items + // included in the sample. + SendSample(x, y types.Ordered, fingerprint types.Fingerprint, count, sampleSize int, seq types.Seq) error + // SendRecent sends recent items. SendRecent(since time.Time) error - // ShortenKey shortens the key for minhash calculation - ShortenKey(k Ordered) Ordered + // ShortenKey shortens the key for minhash calculation. + ShortenKey(k types.Ordered) types.Ordered } type RangeSetReconcilerOption func(r *RangeSetReconciler) @@ -360,22 +359,21 @@ func WithMaxDiff(d float64) RangeSetReconcilerOption { } } -// TODO: RangeSetReconciler should sit in a separate package -// and WithRangeReconcilerLogger should be named WithLogger -func WithRangeReconcilerLogger(log *zap.Logger) RangeSetReconcilerOption { +// WithLogger specifies the logger for RangeSetReconciler. +func WithLogger(log *zap.Logger) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.log = log } } -// WithRecentTimeSpan specifies the time span for recent items +// WithRecentTimeSpan specifies the time span for recent items. func WithRecentTimeSpan(d time.Duration) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.recentTimeSpan = d } } -// Tracer tracks the reconciliation process +// Tracer tracks the reconciliation process. 
type Tracer interface { // OnDumbSync is called when the difference metric exceeds maxDiff and dumb // reconciliation process is used @@ -396,8 +394,7 @@ func WithTracer(t Tracer) RangeSetReconcilerOption { } } -// TBD: rename -func WithRangeReconcilerClock(c clockwork.Clock) RangeSetReconcilerOption { +func WithClock(c clockwork.Clock) RangeSetReconcilerOption { return func(r *RangeSetReconciler) { r.clock = c } @@ -410,7 +407,7 @@ type ProbeResult struct { } type RangeSetReconciler struct { - is ItemStore + os OrderedSet maxSendRange int itemChunkSize int sampleSize int @@ -421,9 +418,9 @@ type RangeSetReconciler struct { log *zap.Logger } -func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *RangeSetReconciler { +func NewRangeSetReconciler(os OrderedSet, opts ...RangeSetReconcilerOption) *RangeSetReconciler { rsr := &RangeSetReconciler{ - is: is, + os: os, maxSendRange: DefaultMaxSendRange, itemChunkSize: DefaultItemChunkSize, sampleSize: DefaultSampleSize, @@ -451,12 +448,13 @@ func NewRangeSetReconciler(is ItemStore, opts ...RangeSetReconcilerOption) *Rang // return fmt.Sprintf("%s", it.Key()) // } -func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info RangeInfo) error { +func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y types.Ordered, info RangeInfo) error { // fmt.Fprintf(os.Stderr, "QQQQQ: preceding=%q\n", // qqqqRmmeK(preceding)) // TODO: don't re-request range info for the first part of range after stop rsr.log.Debug("processSubrange", HexField("x", x), HexField("y", y), zap.Int("count", info.Count), HexField("fingerprint", info.Fingerprint)) + // fmt.Fprintf(os.Stderr, "QQQQQ: start=%q end=%q info.Start=%q info.End=%q info.FP=%q x=%q y=%q\n", // qqqqRmmeK(start), qqqqRmmeK(end), qqqqRmmeK(info.Start), qqqqRmmeK(info.End), info.Fingerprint, x, y) switch { @@ -496,15 +494,14 @@ func (rsr *RangeSetReconciler) processSubrange(c Conduit, x, y Ordered, info Ran func (rsr *RangeSetReconciler) 
splitRange( ctx context.Context, c Conduit, - preceding Iterator, count int, - x, y Ordered, + x, y types.Ordered, ) error { count = count / 2 rsr.log.Debug("handleMessage: PRE split range", HexField("x", x), HexField("y", y), zap.Int("countArg", count)) - si, err := rsr.is.SplitRange(ctx, preceding, x, y, count) + si, err := rsr.os.SplitRange(ctx, x, y, count) if err != nil { return err } @@ -513,12 +510,10 @@ func (rsr *RangeSetReconciler) splitRange( zap.Int("countArg", count), zap.Int("count0", si.Parts[0].Count), HexField("fp0", si.Parts[0].Fingerprint), - IteratorField("start0", si.Parts[0].Start), - IteratorField("end0", si.Parts[0].End), + SeqField("start0", si.Parts[0].Items), zap.Int("count1", si.Parts[1].Count), HexField("fp1", si.Parts[1].Fingerprint), - IteratorField("start1", si.Parts[1].End), - IteratorField("end1", si.Parts[1].End)) + SeqField("start1", si.Parts[1].Items)) if err := rsr.processSubrange(c, x, si.Middle, si.Parts[0]); err != nil { return err } @@ -532,8 +527,8 @@ func (rsr *RangeSetReconciler) splitRange( func (rsr *RangeSetReconciler) sendSmallRange( c Conduit, count int, - it Iterator, - x, y Ordered, + seq types.Seq, + x, y types.Ordered, ) error { if count == 0 { rsr.log.Debug("handleMessage: empty incoming range", @@ -549,49 +544,56 @@ func (rsr *RangeSetReconciler) sendSmallRange( if err := c.SendRangeContents(x, y, count); err != nil { return err } - _, err := rsr.sendItems(c, count, it, nil) + _, err := rsr.sendItems(c, count, seq, nil) return err } func (rsr *RangeSetReconciler) sendItems( c Conduit, count int, - it Iterator, - skipKeys []Ordered, + seq types.Seq, + skipKeys []types.Ordered, ) (int, error) { nSent := 0 skipPos := 0 - for i := 0; i < count; i += rsr.itemChunkSize { - // TBD: do not use chunks, just stream the contentkeys - var keys []Ordered - n := min(rsr.itemChunkSize, count-i) - IN_CHUNK: - for n > 0 { - k, err := it.Key() - if err != nil { - return nSent, err - } - for skipPos < len(skipKeys) { - cmp := 
k.Compare(skipKeys[skipPos]) - if cmp == 0 { - // we can skip this item. Advance skipPos as there are no duplicates - skipPos++ - continue IN_CHUNK - } - if cmp < 0 { - // current ley is yet to reach the skipped key at skipPos - break - } - // current item is greater than the skipped key at skipPos, - // so skipPos needs to catch up with the iterator + if rsr.itemChunkSize == 0 { + panic("BUG: zero item chunk size") + } + var keys []types.Ordered + n := count + for k, err := range seq { + if err != nil { + return nSent, err + } + for skipPos < len(skipKeys) { + cmp := k.Compare(skipKeys[skipPos]) + if cmp == 0 { + // we can skip this item. Advance skipPos as there are no duplicates skipPos++ + continue + } + if cmp < 0 { + // current ley is yet to reach the skipped key at skipPos + break } - keys = append(keys, k) - if err := it.Next(); err != nil { + // current item is greater than the skipped key at skipPos, + // so skipPos needs to catch up with the iterator + skipPos++ + } + if len(keys) == rsr.itemChunkSize { + if err := c.SendChunk(keys); err != nil { return nSent, err } - n-- + nSent += len(keys) + keys = keys[:0] + } + keys = append(keys, k) + n-- + if n == 0 { + break } + } + if len(keys) != 0 { if err := c.SendChunk(keys); err != nil { return nSent, err } @@ -605,12 +607,10 @@ func (rsr *RangeSetReconciler) sendItems( func (rsr *RangeSetReconciler) handleMessage( ctx context.Context, c Conduit, - preceding Iterator, msg SyncMessage, - receivedKeys []Ordered, + receivedKeys []types.Ordered, ) (done bool, err error) { - rsr.log.Debug("handleMessage", IteratorField("preceding", preceding), - zap.String("msg", SyncMessageToString(msg))) + rsr.log.Debug("handleMessage", zap.String("msg", SyncMessageToString(msg))) x := msg.X() y := msg.Y() done = true @@ -619,22 +619,15 @@ func (rsr *RangeSetReconciler) handleMessage( (msg.Type() == MessageTypeProbe && x == nil && y == nil) { // The peer has no items at all so didn't // even send X & Y (SendEmptySet) - it, err 
:= rsr.is.Min(ctx) - if err != nil { - return false, err - } - if it == nil { + if empty, err := rsr.os.Empty(ctx); err != nil { + return false, fmt.Errorf("checking for empty set: %w", err) + } else if empty { // We don't have any items at all, too if msg.Type() == MessageTypeProbe { - info, err := rsr.is.GetRangeInfo(ctx, preceding, nil, nil, -1) - if err != nil { - return false, err - } - rsr.log.Debug("handleMessage: send probe response", - HexField("fingerpint", info.Fingerprint), - zap.Int("count", info.Count), - IteratorField("it", it)) - if err := c.SendSample(x, y, info.Fingerprint, info.Count, 0, it); err != nil { + rsr.log.Debug("handleMessage: send empty probe response") + if err := c.SendSample( + x, y, types.EmptyFingerprint(), 0, 0, types.EmptySeq(), + ); err != nil { return false, err } } @@ -643,22 +636,25 @@ func (rsr *RangeSetReconciler) handleMessage( } return true, nil } - x, err = it.Key() + items, err := rsr.os.Items(ctx) if err != nil { - return false, err + return false, fmt.Errorf("getting items: %w", err) + } + x, err = items.First() + if err != nil { + return false, fmt.Errorf("getting first item: %w", err) } y = x } else if x == nil || y == nil { return false, fmt.Errorf("bad X or Y in a message of type %s", msg.Type()) } - info, err := rsr.is.GetRangeInfo(ctx, preceding, x, y, -1) + info, err := rsr.os.GetRangeInfo(ctx, x, y, -1) if err != nil { return false, err } rsr.log.Debug("handleMessage: range info", HexField("x", x), HexField("y", y), - IteratorField("start", info.Start), - IteratorField("end", info.End), + SeqField("items", info.Items), zap.Int("count", info.Count), HexField("fingerprint", info.Fingerprint)) @@ -676,8 +672,8 @@ func (rsr *RangeSetReconciler) handleMessage( if info.Count != 0 { done = false rsr.log.Debug("handleMessage: send items", zap.Int("count", info.Count), - IteratorField("start", info.Start)) - if _, err := rsr.sendItems(c, info.Count, info.Start, receivedKeys); err != nil { + SeqField("items", 
info.Items)) + if _, err := rsr.sendItems(c, info.Count, info.Items, receivedKeys); err != nil { return false, err } } else { @@ -691,20 +687,20 @@ func (rsr *RangeSetReconciler) handleMessage( } else if sampleSize > info.Count { sampleSize = info.Count } - it := info.Start + items := info.Items if fingerprintEqual(msg.Fingerprint(), info.Fingerprint) { // no need to send MinHash items if fingerprints match - it = nil + items = types.EmptySeq() sampleSize = 0 // fmt.Fprintf(os.Stderr, "QQQQQ: fingerprint eq %#v %#v\n", // msg.Fingerprint(), info.Fingerprint) } - if err := c.SendSample(x, y, info.Fingerprint, info.Count, sampleSize, it); err != nil { + if err := c.SendSample(x, y, info.Fingerprint, info.Count, sampleSize, items); err != nil { return false, err } return true, nil case msg.Type() == MessageTypeRecent: - it, count, err := rsr.is.Recent(ctx, msg.Since()) + it, count, err := rsr.os.Recent(ctx, msg.Since()) if err != nil { return false, fmt.Errorf("error getting recent items: %w", err) } @@ -719,20 +715,6 @@ func (rsr *RangeSetReconciler) handleMessage( zap.Int("receivedCount", len(receivedKeys)), zap.Int("sentCount", nSent)) rsr.tracer.OnRecent(len(receivedKeys), nSent) - // if x == nil { - // // FIXME: code duplication - // it, err := rsr.is.Min(ctx) - // if err != nil { - // return false, err - // } - // if it != nil { - // x, err = it.Key() - // if err != nil { - // return false, err - // } - // y = x - // } - // } return false, rsr.initiateBounded(ctx, c, x, y, false) case msg.Type() != MessageTypeFingerprint && msg.Type() != MessageTypeSample: return false, fmt.Errorf("unexpected message type %s", msg.Type()) @@ -751,7 +733,7 @@ func (rsr *RangeSetReconciler) handleMessage( zap.Float64("sim", pr.Sim), zap.Float64("diff", 1-pr.Sim), zap.Float64("maxDiff", rsr.maxDiff)) - if _, err := rsr.sendItems(c, info.Count, info.Start, nil); err != nil { + if _, err := rsr.sendItems(c, info.Count, info.Items, nil); err != nil { return false, err } return false, 
c.SendRangeContents(x, y, info.Count) @@ -761,34 +743,36 @@ func (rsr *RangeSetReconciler) handleMessage( zap.Float64("diff", 1-pr.Sim), zap.Float64("maxDiff", rsr.maxDiff)) if info.Count > rsr.maxSendRange { - return false, rsr.splitRange(ctx, c, preceding, info.Count, x, y) + return false, rsr.splitRange(ctx, c, info.Count, x, y) } - return false, rsr.sendSmallRange(c, info.Count, info.Start, x, y) + return false, rsr.sendSmallRange(c, info.Count, info.Items, x, y) // case (info.Count+1)/2 <= rsr.maxSendRange: case info.Count <= rsr.maxSendRange: - return false, rsr.sendSmallRange(c, info.Count, info.Start, x, y) + return false, rsr.sendSmallRange(c, info.Count, info.Items, x, y) default: - return false, rsr.splitRange(ctx, c, preceding, info.Count, x, y) + return false, rsr.splitRange(ctx, c, info.Count, x, y) } return done, nil } func (rsr *RangeSetReconciler) Initiate(ctx context.Context, c Conduit) error { - it, err := rsr.is.Min(ctx) - if err != nil { - return err - } - var x Ordered - if it != nil { - x, err = it.Key() + var x types.Ordered + if empty, err := rsr.os.Empty(ctx); err != nil { + return fmt.Errorf("checking for empty set: %w", err) + } else if !empty { + seq, err := rsr.os.Items(ctx) if err != nil { - return err + return fmt.Errorf("error getting items: %w", err) + } + x, err = seq.First() + if err != nil { + return fmt.Errorf("getting first item: %w", err) } } return rsr.InitiateBounded(ctx, c, x, x) } -func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x, y Ordered) error { +func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x, y types.Ordered) error { haveRecent := rsr.recentTimeSpan > 0 if err := rsr.initiateBounded(ctx, c, x, y, haveRecent); err != nil { return err @@ -796,13 +780,13 @@ func (rsr *RangeSetReconciler) InitiateBounded(ctx context.Context, c Conduit, x return c.SendEndRound() } -func (rsr *RangeSetReconciler) initiateBounded(ctx context.Context, c Conduit, x, y Ordered, 
haveRecent bool) error { +func (rsr *RangeSetReconciler) initiateBounded(ctx context.Context, c Conduit, x, y types.Ordered, haveRecent bool) error { rsr.log.Debug("initiate", HexField("x", x), HexField("y", y)) if x == nil { rsr.log.Debug("initiate: send empty set") return c.SendEmptySet() } - info, err := rsr.is.GetRangeInfo(ctx, nil, x, y, -1) + info, err := rsr.os.GetRangeInfo(ctx, x, y, -1) if err != nil { return fmt.Errorf("get range info: %w", err) } @@ -811,14 +795,14 @@ func (rsr *RangeSetReconciler) initiateBounded(ctx context.Context, c Conduit, x panic("empty full min-min range") case info.Count < rsr.maxSendRange: rsr.log.Debug("initiate: send whole range", zap.Int("count", info.Count)) - if _, err := rsr.sendItems(c, info.Count, info.Start, nil); err != nil { + if _, err := rsr.sendItems(c, info.Count, info.Items, nil); err != nil { return err } return c.SendRangeContents(x, y, info.Count) case haveRecent: rsr.log.Debug("initiate: checking recent items") since := rsr.clock.Now().Add(-rsr.recentTimeSpan) - it, count, err := rsr.is.Recent(ctx, since) + it, count, err := rsr.os.Recent(ctx, since) if err != nil { return fmt.Errorf("error getting recent items: %w", err) } @@ -844,7 +828,7 @@ func (rsr *RangeSetReconciler) initiateBounded(ctx context.Context, c Conduit, x rsr.log.Debug("initiate: send sample", zap.Int("count", info.Count), zap.Int("sampleSize", rsr.sampleSize)) - return c.SendSample(x, y, info.Fingerprint, info.Count, rsr.sampleSize, info.Start) + return c.SendSample(x, y, info.Fingerprint, info.Count, rsr.sampleSize, info.Items) default: rsr.log.Debug("initiate: send fingerprint", zap.Int("count", info.Count)) return c.SendFingerprint(x, y, info.Fingerprint, info.Count) @@ -879,9 +863,9 @@ func (rsr *RangeSetReconciler) InitiateProbe(ctx context.Context, c Conduit) (Ra func (rsr *RangeSetReconciler) InitiateBoundedProbe( ctx context.Context, c Conduit, - x, y Ordered, + x, y types.Ordered, ) (RangeInfo, error) { - info, err := 
rsr.is.GetRangeInfo(ctx, nil, x, y, -1) + info, err := rsr.os.GetRangeInfo(ctx, x, y, -1) if err != nil { return RangeInfo{}, err } @@ -895,34 +879,36 @@ func (rsr *RangeSetReconciler) InitiateBoundedProbe( return info, nil } -func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample []Ordered, fp any) (float64, error) { +func (rsr *RangeSetReconciler) calcSim(c Conduit, info RangeInfo, remoteSample []types.Ordered, fp any) (float64, error) { if fingerprintEqual(info.Fingerprint, fp) { return 1, nil } - if info.Start == nil { + if info.Count == 0 { return 0, nil } // for n, k := range remoteSample { // fmt.Fprintf(os.Stderr, "QQQQQ: remoteSample[%d] = %s\n", n, k.(MinhashSampleItem).String()) // } sampleSize := min(info.Count, rsr.sampleSize) - localSample := make([]Ordered, sampleSize) - it := info.Start - for n := 0; n < sampleSize; n++ { - k, err := it.Key() - if err != nil { - return 0, err - } - localSample[n] = c.ShortenKey(k) - // fmt.Fprintf(os.Stderr, "QQQQQ: n %d sampleSize %d info.Count %d rsr.sampleSize %d -- %s -> %s\n", - // n, sampleSize, info.Count, rsr.sampleSize, k.(types.Hash32).String(), - // localSample[n].(MinhashSampleItem).String()) - if err := it.Next(); err != nil { - return 0, err + localSample := make([]types.Ordered, sampleSize) + if sampleSize > 0 { + n := 0 + for k, err := range info.Items { + if err != nil { + return 0, err + } + localSample[n] = c.ShortenKey(k) + // fmt.Fprintf(os.Stderr, "QQQQQ: n %d sampleSize %d info.Count %d rsr.sampleSize %d -- %s -> %s\n", + // n, sampleSize, info.Count, rsr.sampleSize, k.(types.Hash32).String(), + // localSample[n].(MinhashSampleItem).String()) + n++ + if n == sampleSize { + break + } } } - slices.SortFunc(remoteSample, func(a, b Ordered) int { return a.Compare(b) }) - slices.SortFunc(localSample, func(a, b Ordered) int { return a.Compare(b) }) + slices.SortFunc(remoteSample, func(a, b types.Ordered) int { return a.Compare(b) }) + slices.SortFunc(localSample, func(a, b 
types.Ordered) int { return a.Compare(b) }) numEq := 0 for m, n := 0, 0; m < len(localSample) && n < len(remoteSample); { @@ -1013,9 +999,9 @@ func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (p func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done bool, err error) { var msgs []SyncMessage - // All of the messages need to be received before processing - // them, as processing the messages involves sending more - // messages back to the peer + // All of the round's messages need to be received before processing them, as + // processing the messages involves sending more messages back to the peer. + // TODO: use proper goroutines in the wireConduit to deal with send/recv blocking. msgs, done, err = rsr.getMessages(c) if done { // items already added @@ -1025,13 +1011,13 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo return done, nil } done = true - var receivedKeys []Ordered + var receivedKeys []types.Ordered for _, msg := range msgs { if msg.Type() == MessageTypeItemBatch { for _, k := range msg.Keys() { rsr.log.Debug("Process: add item", HexField("item", k)) - if err := rsr.is.Add(ctx, k); err != nil { - return false, fmt.Errorf("error adding an item to the store: %w", err) + if err := rsr.os.Add(ctx, k); err != nil { + return false, fmt.Errorf("error adding an item to the set: %w", err) } receivedKeys = append(receivedKeys, k) } @@ -1048,7 +1034,7 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo // breaks if we capture the iterator from handleMessage and // pass it to the next handleMessage call (it shouldn't) var msgDone bool - msgDone, err = rsr.handleMessage(ctx, c, nil, msg, receivedKeys) + msgDone, err = rsr.handleMessage(ctx, c, msg, receivedKeys) if !msgDone { done = false } @@ -1077,69 +1063,28 @@ func fingerprintEqual(a, b any) bool { return reflect.DeepEqual(a, b) } -type IterEntry[T Ordered] struct { - V T - Err error -} - -func 
IterItems[T Ordered](ctx context.Context, is ItemStore) iter.Seq2[T, error] { - return iter.Seq2[T, error](func(yield func(T, error) bool) { - var empty T - ctx := context.Background() - it, err := is.Min(ctx) - if err != nil { - yield(empty, err) - return - } - if it == nil { - return - } - k, err := it.Key() - if err != nil { - yield(empty, err) - return - } - info, err := is.GetRangeInfo(ctx, nil, k, k, -1) - if err != nil { - yield(empty, err) - return - } - it, err = is.Min(ctx) - if err != nil { - yield(empty, err) - return - } - for n := 0; n < info.Count; n++ { - k, err := it.Key() - if err != nil { - yield(empty, err) - return - } - if k == nil { - // fmt.Fprintf(os.Stderr, "QQQQQ: it: %#v\n", it) - panic("BUG: iterator exausted before Count reached") - } - yield(k.(T), nil) - if err := it.Next(); err != nil { - yield(empty, err) - return - } - } - }) -} - -// CollectStoreItems returns the list of items in the given store -func CollectStoreItems[T Ordered](ctx context.Context, is ItemStore) (r []T, err error) { - for v, err := range IterItems[T](ctx, is) { +// CollectSetItems returns the list of items in the given set +func CollectSetItems[T types.Ordered](ctx context.Context, os OrderedSet) (r []T, err error) { + items, err := os.Items(ctx) + if err != nil { + return nil, err + } + var first types.Ordered + for v, err := range items { if err != nil { return nil, err } - r = append(r, v) + if first == nil { + first = v + } else if v.Compare(first) == 0 { + break + } + r = append(r, v.(T)) } return r, nil } -// TBD: test: add items to the store even in case of NextMessage() failure +// TBD: test: add items to the set even in case of NextMessage() failure // TBD: !!! use wire types instead of multiple Send* methods in the Conduit interface !!! // TBD: !!! queue outbound messages right in RangeSetReconciler while processing msgs, and no need for done in handleMessage this way ++ no need for complicated logic on the conduit part !!! // TBD: !!! 
check that done message present !!! diff --git a/sync2/rangesync/rangesync_test.go b/sync2/rangesync/rangesync_test.go new file mode 100644 index 0000000000..d68c7811be --- /dev/null +++ b/sync2/rangesync/rangesync_test.go @@ -0,0 +1,612 @@ +package rangesync + +import ( + "context" + "math/rand" + "slices" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "golang.org/x/exp/maps" + + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type rangeMessage struct { + mtype MessageType + x, y types.Ordered + fp types.Fingerprint + count int + keys []types.Ordered + since time.Time +} + +var _ SyncMessage = rangeMessage{} + +func (m rangeMessage) Type() MessageType { return m.mtype } +func (m rangeMessage) X() types.Ordered { return m.x } +func (m rangeMessage) Y() types.Ordered { return m.y } +func (m rangeMessage) Fingerprint() types.Fingerprint { return m.fp } +func (m rangeMessage) Count() int { return m.count } +func (m rangeMessage) Keys() []types.Ordered { return m.keys } +func (m rangeMessage) Since() time.Time { return m.since } + +func (m rangeMessage) String() string { + return SyncMessageToString(m) +} + +// fakeConduit is a fake Conduit for testing purposes that connects two +// RangeSetReconcilers together without any network connection, and makes it easier to see +// which messages are being sent and received. 
+type fakeConduit struct { + t *testing.T + msgs []rangeMessage + resp []rangeMessage +} + +var _ Conduit = &fakeConduit{} + +func (fc *fakeConduit) gotoResponse() { + fc.msgs = fc.resp + fc.resp = nil +} + +func (fc *fakeConduit) numItems() int { + n := 0 + for _, m := range fc.msgs { + n += len(m.Keys()) + } + return n +} + +func (fc *fakeConduit) NextMessage() (SyncMessage, error) { + if len(fc.msgs) != 0 { + m := fc.msgs[0] + fc.msgs = fc.msgs[1:] + return m, nil + } + + return nil, nil +} + +func (fc *fakeConduit) sendMsg(msg rangeMessage) { + fc.resp = append(fc.resp, msg) +} + +func (fc *fakeConduit) SendFingerprint(x, y types.Ordered, fingerprint types.Fingerprint, count int) error { + require.NotNil(fc.t, x) + require.NotNil(fc.t, y) + require.NotZero(fc.t, count) + require.NotNil(fc.t, fingerprint) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeFingerprint, + x: x, + y: y, + fp: fingerprint, + count: count, + }) + return nil +} + +func (fc *fakeConduit) SendEmptySet() error { + fc.sendMsg(rangeMessage{mtype: MessageTypeEmptySet}) + return nil +} + +func (fc *fakeConduit) SendEmptyRange(x, y types.Ordered) error { + require.NotNil(fc.t, x) + require.NotNil(fc.t, y) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeEmptyRange, + x: x, + y: y, + }) + return nil +} + +func (fc *fakeConduit) SendRangeContents(x, y types.Ordered, count int) error { + require.NotNil(fc.t, x) + require.NotNil(fc.t, y) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeRangeContents, + x: x, + y: y, + count: count, + }) + return nil +} + +func (fc *fakeConduit) SendChunk(items []types.Ordered) error { + require.NotEmpty(fc.t, items) + fc.sendMsg(rangeMessage{ + mtype: MessageTypeItemBatch, + keys: slices.Clone(items), + }) + return nil +} + +func (fc *fakeConduit) SendEndRound() error { + fc.sendMsg(rangeMessage{mtype: MessageTypeEndRound}) + return nil +} + +func (fc *fakeConduit) SendDone() error { + fc.sendMsg(rangeMessage{mtype: MessageTypeDone}) + return nil +} + +func (fc 
*fakeConduit) SendProbe(x, y types.Ordered, fingerprint types.Fingerprint, sampleSize int) error { + fc.sendMsg(rangeMessage{ + mtype: MessageTypeProbe, + x: x, + y: y, + fp: fingerprint, + count: sampleSize, + }) + return nil +} + +func (fc *fakeConduit) SendSample( + x, y types.Ordered, + fingerprint types.Fingerprint, + count, sampleSize int, + seq types.Seq, +) error { + msg := rangeMessage{ + mtype: MessageTypeSample, + x: x, + y: y, + fp: fingerprint, + count: count, + keys: make([]types.Ordered, sampleSize), + } + n := 0 + for k, err := range seq { + require.NoError(fc.t, err) + require.NotNil(fc.t, k) + msg.keys[n] = k + n++ + if n == sampleSize { + break + } + } + fc.sendMsg(msg) + return nil +} + +func (fc *fakeConduit) SendRecent(since time.Time) error { + fc.sendMsg(rangeMessage{ + mtype: MessageTypeRecent, + since: since, + }) + return nil +} + +func (fc *fakeConduit) ShortenKey(k types.Ordered) types.Ordered { + return k +} + +func makeSet(t *testing.T, items string) *dumbSet { + var s dumbSet + for _, c := range []byte(items) { + require.NoError(t, s.Add(context.Background(), types.KeyBytes{c})) + } + return &s +} + +func setStr(os OrderedSet) string { + ids, err := CollectSetItems[types.KeyBytes](context.Background(), os) + if err != nil { + panic("set error: " + err.Error()) + } + var r strings.Builder + for _, id := range ids { + r.Write(id[:1]) + } + return r.String() +} + +// NOTE: when enabled, this produces A LOT of output during tests (116k+ lines), which +// may be too much if you run the tests in the verbose mode. +// But it's useful for debugging and understanding how sync works, so it's left here for +// now. +var showMessages = false + +func dumpRangeMessages(t *testing.T, msgs []rangeMessage, fmt string, args ...any) { + if !showMessages { + return + } + t.Logf(fmt, args...) 
+ for _, m := range msgs { + t.Logf(" %s", m) + } +} + +func runSync(t *testing.T, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { + fc := &fakeConduit{t: t} + require.NoError(t, syncA.Initiate(context.Background(), fc)) + return doRunSync(fc, syncA, syncB, maxRounds) +} + +func runBoundedSync( + t *testing.T, + syncA, syncB *RangeSetReconciler, + x, y types.Ordered, + maxRounds int, +) (nRounds, nMsg, nItems int) { + fc := &fakeConduit{t: t} + require.NoError(t, syncA.InitiateBounded(context.Background(), fc, x, y)) + return doRunSync(fc, syncA, syncB, maxRounds) +} + +func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) (nRounds, nMsg, nItems int) { + var i int + aDone, bDone := false, false + dumpRangeMessages(fc.t, fc.resp, "A %q -> B %q (init):", setStr(syncA.os), setStr(syncB.os)) + dumpRangeMessages(fc.t, fc.resp, "A -> B (init):") + for i = 0; ; i++ { + if i == maxRounds { + require.FailNow(fc.t, "too many rounds", "didn't reconcile in %d rounds", i) + } + fc.gotoResponse() + nMsg += len(fc.msgs) + nItems += fc.numItems() + var err error + bDone, err = syncB.Process(context.Background(), fc) + require.NoError(fc.t, err) + // a party should never send anything in response to the "done" message + require.False(fc.t, aDone && !bDone, "A is done but B after that is not") + dumpRangeMessages(fc.t, fc.resp, "B %q -> A %q:", setStr(syncA.os), setStr(syncB.os)) + dumpRangeMessages(fc.t, fc.resp, "B -> A:") + if aDone && bDone { + require.Empty(fc.t, fc.resp, "got messages from B in response to done msg from A") + break + } + fc.gotoResponse() + nMsg += len(fc.msgs) + nItems += fc.numItems() + aDone, err = syncA.Process(context.Background(), fc) + require.NoError(fc.t, err) + dumpRangeMessages(fc.t, fc.msgs, "A %q --> B %q:", setStr(syncB.os), setStr(syncA.os)) + dumpRangeMessages(fc.t, fc.resp, "A -> B:") + require.False(fc.t, bDone && !aDone, "B is done but A after that is not") + if aDone && bDone { + 
require.Empty(fc.t, fc.resp, "got messages from A in response to done msg from B") + break + } + } + return i + 1, nMsg, nItems +} + +func runProbe(t *testing.T, from, to *RangeSetReconciler) ProbeResult { + fc := &fakeConduit{t: t} + info, err := from.InitiateProbe(context.Background(), fc) + require.NoError(t, err) + return doRunProbe(fc, from, to, info) +} + +func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y types.Ordered) ProbeResult { + fc := &fakeConduit{t: t} + info, err := from.InitiateBoundedProbe(context.Background(), fc, x, y) + require.NoError(t, err) + return doRunProbe(fc, from, to, info) +} + +func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler, info RangeInfo) ProbeResult { + require.NotEmpty(fc.t, fc.resp, "empty initial round") + fc.gotoResponse() + done, err := to.Process(context.Background(), fc) + require.True(fc.t, done) + require.NoError(fc.t, err) + fc.gotoResponse() + pr, err := from.HandleProbeResponse(fc, info) + require.NoError(fc.t, err) + require.Nil(fc.t, fc.resp, "got messages from Probe in response to done msg") + return pr +} + +func TestRangeSync(t *testing.T) { + for _, tc := range []struct { + name string + a, b string + finalA, finalB string + x, y string + countA, countB int + fpA, fpB types.Fingerprint + maxRounds [4]int + sim float64 + }{ + { + name: "empty sets", + a: "", + b: "", + finalA: "", + finalB: "", + countA: 0, + countB: 0, + fpA: types.EmptyFingerprint(), + fpB: types.EmptyFingerprint(), + maxRounds: [4]int{1, 1, 1, 1}, + sim: 1, + }, + { + name: "empty to non-empty", + a: "", + b: "abcd", + finalA: "abcd", + finalB: "abcd", + countA: 0, + countB: 4, + fpA: types.EmptyFingerprint(), + fpB: stringToFP("abcd"), + maxRounds: [4]int{2, 2, 2, 2}, + sim: 0, + }, + { + name: "non-empty to empty", + a: "abcd", + b: "", + finalA: "abcd", + finalB: "abcd", + countA: 4, + countB: 0, + fpA: stringToFP("abcd"), + fpB: types.EmptyFingerprint(), + maxRounds: [4]int{2, 2, 2, 2}, + sim: 0, + }, + { + 
name: "non-intersecting sets", + a: "ab", + b: "cd", + finalA: "abcd", + finalB: "abcd", + countA: 2, + countB: 2, + fpA: stringToFP("ab"), + fpB: stringToFP("cd"), + maxRounds: [4]int{3, 2, 2, 2}, + sim: 0, + }, + { + name: "intersecting sets", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "abcdefghijklmnopqr", + finalB: "abcdefghijklmnopqr", + countA: 13, + countB: 7, + fpA: stringToFP("acdefghijklmn"), + fpB: stringToFP("bcdopqr"), + maxRounds: [4]int{4, 4, 3, 3}, + sim: 0.153, + }, + { + name: "bounded reconciliation", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "abcdefghijklmn", + finalB: "abcdefgopqr", + x: "a", + y: "h", + countA: 6, + countB: 3, + fpA: stringToFP("acdefg"), + fpB: stringToFP("bcd"), + maxRounds: [4]int{3, 3, 2, 2}, + sim: 0.333, + }, + { + name: "bounded reconciliation with rollover", + a: "acdefghijklmn", + b: "bcdopqr", + finalA: "acdefghijklmnopqr", + finalB: "bcdhijklmnopqr", + x: "h", + y: "a", + countA: 7, + countB: 4, + fpA: stringToFP("hijklmn"), + fpB: stringToFP("opqr"), + maxRounds: [4]int{4, 3, 3, 2}, + sim: 0, + }, + { + name: "sync against 1-element set", + a: "bcd", + b: "a", + finalA: "abcd", + finalB: "abcd", + countA: 3, + countB: 1, + fpA: stringToFP("bcd"), + fpB: stringToFP("a"), + maxRounds: [4]int{2, 2, 2, 2}, + sim: 0, + }, + } { + t.Run(tc.name, func(t *testing.T) { + logger := zaptest.NewLogger(t) + for n, maxSendRange := range []int{1, 2, 3, 4} { + t.Logf("maxSendRange: %d", maxSendRange) + setA := makeSet(t, tc.a) + setA.disableReAdd = true + syncA := NewRangeSetReconciler(setA, + WithLogger(logger.Named("A")), + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + setB := makeSet(t, tc.b) + setB.disableReAdd = true + syncB := NewRangeSetReconciler(setB, + WithLogger(logger.Named("B")), + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + + var ( + nRounds int + prBA, prAB ProbeResult + ) + if tc.x == "" { + prBA = runProbe(t, syncB, syncA) + prAB = runProbe(t, syncA, syncB) + nRounds, _, _ = 
runSync(t, syncA, syncB, tc.maxRounds[n]) + } else { + x := types.KeyBytes(tc.x) + y := types.KeyBytes(tc.y) + prBA = runBoundedProbe(t, syncB, syncA, x, y) + prAB = runBoundedProbe(t, syncA, syncB, x, y) + nRounds, _, _ = runBoundedSync(t, syncA, syncB, x, y, tc.maxRounds[n]) + } + t.Logf("%s: maxSendRange %d: %d rounds", tc.name, maxSendRange, nRounds) + + require.Equal(t, tc.countA, prBA.Count, "countA") + require.Equal(t, tc.countB, prAB.Count, "countB") + require.Equal(t, tc.fpA, prBA.FP, "fpA") + require.Equal(t, tc.fpB, prAB.FP, "fpB") + require.Equal(t, tc.finalA, setStr(setA), "finalA") + require.Equal(t, tc.finalB, setStr(setB), "finalB") + require.InDelta(t, tc.sim, prAB.Sim, 0.01, "prAB.Sim") + require.InDelta(t, tc.sim, prBA.Sim, 0.01, "prBA.Sim") + } + }) + } +} + +func TestRandomSync(t *testing.T) { + var bytesA, bytesB []byte + defer func() { + if t.Failed() { + t.Logf("Random sync failed: %q <-> %q", bytesA, bytesB) + } + }() + for i := 0; i < 1000; i++ { + var chars []byte + for c := byte(33); c < 127; c++ { + chars = append(chars, c) + } + + bytesA = append([]byte(nil), chars...) + rand.Shuffle(len(bytesA), func(i, j int) { + bytesA[i], bytesA[j] = bytesA[j], bytesA[i] + }) + bytesA = bytesA[:rand.Intn(len(bytesA))] + setA := makeSet(t, string(bytesA)) + + bytesB = append([]byte(nil), chars...) + rand.Shuffle(len(bytesB), func(i, j int) { + bytesB[i], bytesB[j] = bytesB[j], bytesB[i] + }) + bytesB = bytesB[:rand.Intn(len(bytesB))] + setB := makeSet(t, string(bytesB)) + + keySet := make(map[byte]struct{}) + for _, c := range append(bytesA, bytesB...) 
{ + keySet[byte(c)] = struct{}{} + } + + expectedSet := maps.Keys(keySet) + slices.Sort(expectedSet) + + maxSendRange := rand.Intn(16) + 1 + syncA := NewRangeSetReconciler(setA, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + syncB := NewRangeSetReconciler(setB, + WithMaxSendRange(maxSendRange), + WithItemChunkSize(3)) + + runSync(t, syncA, syncB, max(len(expectedSet), 2)) // FIXME: less rounds! + // t.Logf("maxSendRange %d a %d b %d n %d", maxSendRange, len(bytesA), len(bytesB), n) + require.Equal(t, setStr(setA), setStr(setB)) + require.Equal(t, string(expectedSet), setStr(setA), + "expected set for %q<->%q", bytesA, bytesB) + } +} + +type hashSyncTestConfig struct { + maxSendRange int + numTestHashes int + minNumSpecificA int + maxNumSpecificA int + minNumSpecificB int + maxNumSpecificB int + allowReAdd bool +} + +type hashSyncTester struct { + t *testing.T + cfg hashSyncTestConfig + src []types.KeyBytes + setA, setB OrderedSet + opts []RangeSetReconcilerOption + numSpecificA int + numSpecificB int +} + +func newHashSyncTester(t *testing.T, cfg hashSyncTestConfig) *hashSyncTester { + st := &hashSyncTester{ + t: t, + cfg: cfg, + src: make([]types.KeyBytes, cfg.numTestHashes), + opts: []RangeSetReconcilerOption{ + WithMaxSendRange(cfg.maxSendRange), + WithMaxDiff(0.1), + }, + numSpecificA: rand.Intn(cfg.maxNumSpecificA+1-cfg.minNumSpecificA) + cfg.minNumSpecificA, + numSpecificB: rand.Intn(cfg.maxNumSpecificB+1-cfg.minNumSpecificB) + cfg.minNumSpecificB, + } + + for n := range st.src { + st.src[n] = types.RandomKeyBytes(32) + } + + sliceA := st.src[:cfg.numTestHashes-st.numSpecificB] + st.setA = NewDumbHashSet(!cfg.allowReAdd) + for _, h := range sliceA { + require.NoError(t, st.setA.Add(context.Background(), h)) + } + + sliceB := slices.Clone(st.src[:cfg.numTestHashes-st.numSpecificB-st.numSpecificA]) + sliceB = append(sliceB, st.src[cfg.numTestHashes-st.numSpecificB:]...) 
+ st.setB = NewDumbHashSet(!cfg.allowReAdd) + for _, h := range sliceB { + require.NoError(t, st.setB.Add(context.Background(), h)) + } + + slices.SortFunc(st.src, func(a, b types.KeyBytes) int { + return a.Compare(b) + }) + + return st +} + +func (st *hashSyncTester) verify() { + itemsA, err := CollectSetItems[types.KeyBytes](context.Background(), st.setA) + require.NoError(st.t, err) + itemsB, err := CollectSetItems[types.KeyBytes](context.Background(), st.setB) + require.NoError(st.t, err) + require.Equal(st.t, itemsA, itemsB) + require.Equal(st.t, st.src, itemsA) +} + +func TestSyncHash(t *testing.T) { + st := newHashSyncTester(t, hashSyncTestConfig{ + maxSendRange: 1, + numTestHashes: 10000, + minNumSpecificA: 4, + maxNumSpecificA: 100, + minNumSpecificB: 4, + maxNumSpecificB: 100, + }) + syncA := NewRangeSetReconciler(st.setA, st.opts...) + syncB := NewRangeSetReconciler(st.setB, st.opts...) + nRounds, nMsg, nItems := runSync(t, syncA, syncB, 100) + numSpecific := st.numSpecificA + st.numSpecificB + itemCoef := float64(nItems) / float64(numSpecific) + t.Logf("numSpecific: %d, nRounds: %d, nMsg: %d, nItems: %d, itemCoef: %.2f", + numSpecific, nRounds, nMsg, nItems, itemCoef) + st.verify() +} diff --git a/sync2/rangesync/wire_conduit.go b/sync2/rangesync/wire_conduit.go new file mode 100644 index 0000000000..c6aa8c2081 --- /dev/null +++ b/sync2/rangesync/wire_conduit.go @@ -0,0 +1,232 @@ +package rangesync + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "time" + + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type sendable interface { + codec.Encodable + Type() MessageType +} + +type wireConduit struct { + stream io.ReadWriter + initReqBuf *bytes.Buffer +} + +var _ Conduit = &wireConduit{} + +func (c *wireConduit) NextMessage() (SyncMessage, error) { + var b [1]byte + if _, err := io.ReadFull(c.stream, b[:]); err != nil { + if !errors.Is(err, io.EOF) { + return nil, err + } + return nil, nil 
+ } + mtype := MessageType(b[0]) + switch mtype { + case MessageTypeDone: + return &DoneMessage{}, nil + case MessageTypeEndRound: + return &EndRoundMessage{}, nil + case MessageTypeItemBatch: + var m ItemBatchMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeEmptySet: + return &EmptySetMessage{}, nil + case MessageTypeEmptyRange: + var m EmptyRangeMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeFingerprint: + var m FingerprintMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeRangeContents: + var m RangeContentsMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeProbe: + var m ProbeMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeSample: + var m SampleMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + case MessageTypeRecent: + var m RecentMessage + if _, err := codec.DecodeFrom(c.stream, &m); err != nil { + return nil, err + } + return &m, nil + default: + return nil, fmt.Errorf("invalid message code %02x", b[0]) + } +} + +func (c *wireConduit) send(m sendable) error { + var stream io.Writer + if c.initReqBuf != nil { + stream = c.initReqBuf + } else if c.stream == nil { + panic("BUG: wireConduit: no stream") + } else { + stream = c.stream + } + b := []byte{byte(m.Type())} + if _, err := stream.Write(b); err != nil { + return err + } + _, err := codec.EncodeTo(stream, m) + return err +} + +func (c *wireConduit) SendFingerprint(x, y types.Ordered, fp types.Fingerprint, count int) error { + return c.send(&FingerprintMessage{ + RangeX: OrderedToCompactHash(x), + RangeY: OrderedToCompactHash(y), + RangeFingerprint: fp, + NumItems: 
uint32(count), + }) +} + +func (c *wireConduit) SendEmptySet() error { + return c.send(&EmptySetMessage{}) +} + +func (c *wireConduit) SendEmptyRange(x, y types.Ordered) error { + return c.send(&EmptyRangeMessage{ + RangeX: OrderedToCompactHash(x), + RangeY: OrderedToCompactHash(y), + }) +} + +func (c *wireConduit) SendRangeContents(x, y types.Ordered, count int) error { + return c.send(&RangeContentsMessage{ + RangeX: OrderedToCompactHash(x), + RangeY: OrderedToCompactHash(y), + NumItems: uint32(count), + }) +} + +func (c *wireConduit) SendChunk(items []types.Ordered) error { + msg := ItemBatchMessage{ + ContentKeys: KeyCollection{ + Keys: make([]types.KeyBytes, len(items)), + }, + } + for n, k := range items { + msg.ContentKeys.Keys[n] = k.(types.KeyBytes) + } + return c.send(&msg) +} + +func (c *wireConduit) SendEndRound() error { + return c.send(&EndRoundMessage{}) +} + +func (c *wireConduit) SendDone() error { + return c.send(&DoneMessage{}) +} + +func (c *wireConduit) SendProbe(x, y types.Ordered, fp types.Fingerprint, sampleSize int) error { + m := &ProbeMessage{ + RangeFingerprint: fp, + SampleSize: uint32(sampleSize), + } + if x == nil && y == nil { + return c.send(m) + } else if x == nil || y == nil { + panic("BUG: SendProbe: bad range: just one of the bounds is nil") + } + m.RangeX = OrderedToCompactHash(x) + m.RangeY = OrderedToCompactHash(y) + return c.send(m) +} + +func (c *wireConduit) SendSample( + x, y types.Ordered, + fp types.Fingerprint, + count, sampleSize int, + seq types.Seq, +) error { + m := &SampleMessage{ + RangeFingerprint: fp, + NumItems: uint32(count), + Sample: make([]MinhashSampleItem, sampleSize), + } + n := 0 + for k, err := range seq { + if err != nil { + return err + } + m.Sample[n] = MinhashSampleItemFromKeyBytes(k.(types.KeyBytes)) + n++ + if n == sampleSize { + break + } + } + if x == nil && y == nil { + return c.send(m) + } else if x == nil || y == nil { + panic("BUG: SendProbe: bad range: just one of the bounds is nil") + } 
+ m.RangeX = OrderedToCompactHash(x) + m.RangeY = OrderedToCompactHash(y) + return c.send(m) +} + +func (c *wireConduit) SendRecent(since time.Time) error { + return c.send(&RecentMessage{ + SinceTime: uint64(since.UnixNano()), + }) +} + +func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, error) { + c.initReqBuf = new(bytes.Buffer) + defer func() { c.initReqBuf = nil }() + if err := toCall(c); err != nil { + return nil, err + } + return c.initReqBuf.Bytes(), nil +} + +func (c *wireConduit) handleStream(ctx context.Context, stream io.ReadWriter, rsr *RangeSetReconciler) error { + c.stream = stream + for { + // Process() will receive all items and messages from the peer + syncDone, err := rsr.Process(ctx, c) + if err != nil { + return err + } else if syncDone { + return nil + } + } +} + +func (c *wireConduit) ShortenKey(k types.Ordered) types.Ordered { + return MinhashSampleItemFromKeyBytes(k.(types.KeyBytes)) +} diff --git a/sync2/rangesync/wire_conduit_test.go b/sync2/rangesync/wire_conduit_test.go new file mode 100644 index 0000000000..d3cb9a066d --- /dev/null +++ b/sync2/rangesync/wire_conduit_test.go @@ -0,0 +1,315 @@ +package rangesync + +import ( + "bytes" + "context" + "fmt" + "io" + "slices" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/p2p/server" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +type incomingRequest struct { + initialRequest []byte + stream io.ReadWriter +} + +type fakeRequester struct { + id p2p.Peer + handler server.StreamHandler + peers map[p2p.Peer]*fakeRequester + reqCh chan incomingRequest +} + +var _ Requester = &fakeRequester{} + +func newFakeRequester(id p2p.Peer, handler server.StreamHandler, peers ...Requester) *fakeRequester { + fr := &fakeRequester{ + id: id, + handler: handler, + reqCh: make(chan incomingRequest), + peers: make(map[p2p.Peer]*fakeRequester), + } + for _, 
p := range peers { + pfr := p.(*fakeRequester) + fr.peers[pfr.id] = pfr + } + return fr +} + +func (fr *fakeRequester) Run(ctx context.Context) error { + if fr.handler == nil { + panic("no handler") + } + for { + var req incomingRequest + select { + case <-ctx.Done(): + return nil + case req = <-fr.reqCh: + } + if err := fr.handler(ctx, req.initialRequest, req.stream); err != nil { + panic("handler error: " + err.Error()) + } + } +} + +func (fr *fakeRequester) request( + ctx context.Context, + pid p2p.Peer, + initialRequest []byte, + callback server.StreamRequestCallback, +) error { + p, found := fr.peers[pid] + if !found { + return fmt.Errorf("bad peer %q", pid) + } + r, w := io.Pipe() + defer r.Close() + defer w.Close() + stream := struct { + io.Reader + io.Writer + }{ + Reader: r, + Writer: w, + } + select { + case p.reqCh <- incomingRequest{ + initialRequest: initialRequest, + stream: stream, + }: + case <-ctx.Done(): + return ctx.Err() + } + return callback(ctx, stream) +} + +func (fr *fakeRequester) StreamRequest( + ctx context.Context, + pid p2p.Peer, + initialRequest []byte, + callback server.StreamRequestCallback, + extraProtocols ...string, +) error { + return fr.request(ctx, pid, initialRequest, callback) +} + +type fakeSend struct { + x, y types.Ordered + count int + fp types.Fingerprint + items []types.Ordered + endRound bool + done bool +} + +func (fs *fakeSend) send(c Conduit) error { + switch { + case fs.endRound: + return c.SendEndRound() + case fs.done: + return c.SendDone() + case len(fs.items) != 0: + return c.SendChunk(slices.Clone(fs.items)) + case fs.x == nil || fs.y == nil: + return c.SendEmptySet() + case fs.count == 0: + return c.SendEmptyRange(fs.x, fs.y) + case fs.fp != types.EmptyFingerprint(): + return c.SendFingerprint(fs.x, fs.y, fs.fp, fs.count) + default: + return c.SendRangeContents(fs.x, fs.y, fs.count) + } +} + +type fakeRound struct { + name string + expectMsgs []SyncMessage + toSend []*fakeSend +} + +func (r *fakeRound) 
handleMessages(t *testing.T, c Conduit) error { + var msgs []SyncMessage + for { + msg, err := c.NextMessage() + if err != nil { + return fmt.Errorf("NextMessage(): %w", err) + } else if msg == nil { + break + } + msgs = append(msgs, msg) + if msg.Type() == MessageTypeDone || msg.Type() == MessageTypeEndRound { + break + } + } + require.Equal(t, r.expectMsgs, msgs, "messages for round %q", r.name) + return nil +} + +func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { + if err := r.handleMessages(t, c); err != nil { + return err + } + for _, s := range r.toSend { + if err := s.send(c); err != nil { + return err + } + } + return nil +} + +func makeTestStreamHandler(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamHandler { + cbk := makeTestRequestCallback(t, c, rounds) + return func(ctx context.Context, initialRequest []byte, stream io.ReadWriter) error { + t.Logf("init request bytes: %d", len(initialRequest)) + s := struct { + io.Reader + io.Writer + }{ + // prepend the received request to data being read + Reader: io.MultiReader(bytes.NewBuffer(initialRequest), stream), + Writer: stream, + } + return cbk(ctx, s) + } +} + +func makeTestRequestCallback(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamRequestCallback { + return func(ctx context.Context, stream io.ReadWriter) error { + if c == nil { + c = &wireConduit{stream: stream} + } else { + c.stream = stream + } + for _, round := range rounds { + if err := round.handleConversation(t, c); err != nil { + return err + } + } + return nil + } +} + +func TestWireConduit(t *testing.T) { + hs := make([]types.KeyBytes, 16) + for n := range hs { + hs[n] = types.RandomKeyBytes(32) + } + fp := types.Fingerprint(hs[2][:12]) + srvHandler := makeTestStreamHandler(t, nil, []fakeRound{ + { + name: "server got 1st request", + expectMsgs: []SyncMessage{ + &FingerprintMessage{ + RangeX: KeyBytesToCompact(hs[0]), + RangeY: KeyBytesToCompact(hs[1]), + RangeFingerprint: fp, + 
NumItems: 4, + }, + &EndRoundMessage{}, + }, + toSend: []*fakeSend{ + { + x: hs[0], + y: hs[3], + count: 2, + }, + { + x: hs[3], + y: hs[6], + count: 2, + }, + { + items: []types.Ordered{hs[4], hs[5], hs[7], hs[8]}, + }, + { + endRound: true, + }, + }, + }, + { + name: "server got 2nd request", + expectMsgs: []SyncMessage{ + &ItemBatchMessage{ + ContentKeys: KeyCollection{ + Keys: []types.KeyBytes{hs[9], hs[10], hs[11]}, + }, + }, + &EndRoundMessage{}, + }, + toSend: []*fakeSend{ + { + done: true, + }, + }, + }, + }) + + srv := newFakeRequester("srv", srvHandler) + var eg errgroup.Group + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + eg.Wait() + }() + eg.Go(func() error { + return srv.Run(ctx) + }) + + client := newFakeRequester("client", nil, srv) + var c wireConduit + initReq, err := c.withInitialRequest(func(c Conduit) error { + if err := c.SendFingerprint(hs[0], hs[1], fp, 4); err != nil { + return err + } + return c.SendEndRound() + }) + require.NoError(t, err) + clientCbk := makeTestRequestCallback(t, &c, []fakeRound{ + { + name: "client got 1st response", + expectMsgs: []SyncMessage{ + &RangeContentsMessage{ + RangeX: KeyBytesToCompact(hs[0]), + RangeY: KeyBytesToCompact(hs[3]), + NumItems: 2, + }, + &RangeContentsMessage{ + RangeX: KeyBytesToCompact(hs[3]), + RangeY: KeyBytesToCompact(hs[6]), + NumItems: 2, + }, + &ItemBatchMessage{ + ContentKeys: KeyCollection{ + Keys: []types.KeyBytes{hs[4], hs[5], hs[7], hs[8]}, + }, + }, + &EndRoundMessage{}, + }, + toSend: []*fakeSend{ + { + items: []types.Ordered{hs[9], hs[10], hs[11]}, + }, + { + endRound: true, + }, + }, + }, + { + name: "client got 2nd response", + expectMsgs: []SyncMessage{ + &DoneMessage{}, + }, + }, + }) + err = client.StreamRequest(context.Background(), "srv", initReq, clientCbk) + require.NoError(t, err) +} diff --git a/sync2/rangesync/wire_helpers.go b/sync2/rangesync/wire_helpers.go new file mode 100644 index 0000000000..986071595c --- /dev/null +++ 
b/sync2/rangesync/wire_helpers.go @@ -0,0 +1,173 @@ +package rangesync + +import ( + "errors" + "fmt" + + "github.com/spacemeshos/go-scale" + "github.com/spacemeshos/go-spacemesh/sync2/types" +) + +// CompactHash encodes hashes in a compact form, skipping trailing zeroes. +// It also supports a nil hash (no value). +// The encoding format is as follows: +// byte 0: spec byte +// bytes 1..n: data bytes +// +// The format of the spec byte is as follows: +// bits 0..5: number of non-zero leading bytes +// bits 6..7: hash type +// +// The following hash types are supported: +// 0: nil hash +// 1: 32-byte hash +// 2,3: reserved + +// NOTE: when adding new hash types, we need to add a mechanism that makes sure that every +// received hash is of the expected type. Alternatively, we need to add some kind of +// context to the scale.Decoder / scale.Encoder, which may contain the size of hashes to +// be used. + +const ( + compactHashTypeNil = 0 + compactHashType32 = 1 + compactHashSizeBits = 6 + maxCompactHashSize = 32 +) + +var errInvalidCompactHash = errors.New("invalid compact hash") + +type CompactHash struct { + H types.KeyBytes +} + +// DecodeScale implements scale.Decodable. +func (c *CompactHash) DecodeScale(dec *scale.Decoder) (int, error) { + var h [maxCompactHashSize]byte + b, total, err := scale.DecodeByte(dec) + switch { + case err != nil: + return total, err + case b>>compactHashSizeBits == compactHashTypeNil: + c.H = nil + return total, nil + case b>>compactHashSizeBits != compactHashType32: + return total, errInvalidCompactHash + case b != 0: + l := b & ((1 << compactHashSizeBits) - 1) + n, err := scale.DecodeByteArray(dec, h[:l]) + total += n + if err != nil { + return total, err + } + } + c.H = h[:] + return total, nil +} + +// EncodeScale implements scale.Encodable. 
+func (c *CompactHash) EncodeScale(enc *scale.Encoder) (int, error) { + if c.H == nil { + return scale.EncodeByte(enc, compactHashTypeNil< 0; b-- { + if c.H[b-1] != 0 { + break + } + } + + total, err := scale.EncodeByte(enc, b|(compactHashType32<= 0; i-- { + k[i]++ + if k[i] != 0 { + return false + } + } + + return true +} + +func (k KeyBytes) Zero() { + for i := range k { + k[i] = 0 + } +} + +func (k KeyBytes) IsZero() bool { + for _, b := range k { + if b != 0 { + return false + } + } + return true +} + +// RandomKeyBytes generates random data in bytes for testing. +func RandomKeyBytes(size int) KeyBytes { + b := make([]byte, size) + _, err := rand.Read(b) + if err != nil { + return nil + } + return b +} + +func HexToKeyBytes(s string) KeyBytes { + b, err := hex.DecodeString(s) + if err != nil { + panic("bad hex key bytes: " + err.Error()) + } + return KeyBytes(b) +} + +type Fingerprint [FingerprintSize]byte + +func RandomFingerprint() Fingerprint { + var fp Fingerprint + _, err := rand.Read(fp[:]) + if err != nil { + panic("failed to generate random fingerprint: " + err.Error()) + } + return fp +} + +func EmptyFingerprint() Fingerprint { + return Fingerprint{} +} + +func (fp Fingerprint) Compare(other Fingerprint) int { + return bytes.Compare(fp[:], other[:]) +} + +func (fp Fingerprint) String() string { + return hex.EncodeToString(fp[:]) +} + +func (fp *Fingerprint) Update(h []byte) { + for n := range *fp { + (*fp)[n] ^= h[n] + } +} + +func (fp *Fingerprint) BitFromLeft(n int) bool { + if n > FingerprintSize*8 { + panic("BUG: bad fingerprint bit index") + } + return (fp[n>>3]>>(7-n&0x7))&1 != 0 +} + +func HexToFingerprint(s string) Fingerprint { + b, err := hex.DecodeString(s) + if err != nil { + panic("bad hex fingerprint: " + err.Error()) + } + var fp Fingerprint + if len(b) != len(fp) { + panic("bad hex fingerprint") + } + copy(fp[:], b) + return fp +} diff --git a/sync2/types/types_test.go b/sync2/types/types_test.go new file mode 100644 index 
0000000000..0c4ace6c87 --- /dev/null +++ b/sync2/types/types_test.go @@ -0,0 +1,101 @@ +package types + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +var fakeSeq Seq = func(yield func(Ordered, error) bool) { + items := []KeyBytes{ + {1}, + {2}, + {3}, + {4}, + } + for { + for _, item := range items { + if !yield(item, nil) { + return + } + } + } +} + +var fakeErr = errors.New("fake error") + +var fakeErrSeq Seq = func(yield func(Ordered, error) bool) { + items := []KeyBytes{ + {1}, + {2}, + } + for _, item := range items { + if !yield(item, nil) { + return + } + } + yield(nil, fakeErr) +} + +var fakeErrOnlySeq Seq = func(yield func(Ordered, error) bool) { + yield(nil, fakeErr) +} + +func TestFirst(t *testing.T) { + k, err := fakeSeq.First() + require.NoError(t, err) + require.Equal(t, KeyBytes{1}, k) + k, err = fakeErrSeq.First() + require.NoError(t, err) + require.Equal(t, KeyBytes{1}, k) + k, err = fakeErrOnlySeq.First() + require.Equal(t, fakeErr, err) + require.Nil(t, k) +} + +func TestGetN(t *testing.T) { + actual, err := GetN[KeyBytes](fakeSeq, 2) + require.NoError(t, err) + require.Equal(t, []KeyBytes{{1}, {2}}, actual) + actual, err = GetN[KeyBytes](fakeSeq, 5) + require.NoError(t, err) + require.Equal(t, []KeyBytes{{1}, {2}, {3}, {4}, {1}}, actual) + actual, err = GetN[KeyBytes](fakeErrSeq, 2) + require.NoError(t, err) + require.Equal(t, []KeyBytes{{1}, {2}}, actual) + actual, err = GetN[KeyBytes](fakeErrSeq, 5) + require.Equal(t, fakeErr, err) + require.Nil(t, actual) + actual, err = GetN[KeyBytes](fakeErrOnlySeq, 2) + require.Equal(t, fakeErr, err) + require.Nil(t, actual) +} + +func TestIncID(t *testing.T) { + for _, tc := range []struct { + id, expected KeyBytes + overflow bool + }{ + { + id: KeyBytes{0x00, 0x00, 0x00, 0x00}, + expected: KeyBytes{0x00, 0x00, 0x00, 0x01}, + overflow: false, + }, + { + id: KeyBytes{0x00, 0x00, 0x00, 0xff}, + expected: KeyBytes{0x00, 0x00, 0x01, 0x00}, + overflow: false, + }, + { + id: 
KeyBytes{0xff, 0xff, 0xff, 0xff}, + expected: KeyBytes{0x00, 0x00, 0x00, 0x00}, + overflow: true, + }, + } { + id := make(KeyBytes, len(tc.id)) + copy(id, tc.id) + require.Equal(t, tc.overflow, id.Inc()) + require.Equal(t, tc.expected, id) + } +} From f9ed0a0bb73df1891352e7de72ececfc091eb662 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 4 Sep 2024 10:14:16 +0400 Subject: [PATCH 72/76] sync2: fix combined sequence tests --- sync2/dbsync/combine_seqs_test.go | 184 +++++++++++++++--------------- 1 file changed, 92 insertions(+), 92 deletions(-) diff --git a/sync2/dbsync/combine_seqs_test.go b/sync2/dbsync/combine_seqs_test.go index 1e7572306f..7eebd7bd5d 100644 --- a/sync2/dbsync/combine_seqs_test.go +++ b/sync2/dbsync/combine_seqs_test.go @@ -106,98 +106,98 @@ func TestCombineSeqs(t *testing.T) { result string startingPoint string }{ - // { - // seqs: []string{"abcd"}, - // indices: []int{0}, - // result: "abcd", - // startingPoint: "a", - // }, - // { - // seqs: []string{"abcd"}, - // indices: []int{0}, - // result: "abcd", - // startingPoint: "c", - // }, - // { - // seqs: []string{"abcd"}, - // indices: []int{2}, - // result: "cdab", - // startingPoint: "c", - // }, - // { - // seqs: []string{"abcd$"}, - // indices: []int{0}, - // result: "abcd$", - // startingPoint: "a", - // }, - // { - // seqs: []string{"abcd!"}, - // indices: []int{0}, - // result: "abcd!", - // startingPoint: "a", - // }, - // { - // seqs: []string{"abcd", "efgh"}, - // indices: []int{0, 0}, - // result: "abcdefgh", - // startingPoint: "a", - // }, - // { - // seqs: []string{"aceg", "bdfh"}, - // indices: []int{0, 0}, - // result: "abcdefgh", - // startingPoint: "a", - // }, - // { - // seqs: []string{"abcd$", "efgh$"}, - // indices: []int{0, 0}, - // result: "abcdefgh$", - // startingPoint: "a", - // }, - // { - // seqs: []string{"aceg$", "bdfh$"}, - // indices: []int{0, 0}, - // result: "abcdefgh$", - // startingPoint: "a", - // }, - // { - // seqs: []string{"abcd!", "efgh!"}, - 
// indices: []int{0, 0}, - // result: "abcd!", - // startingPoint: "a", - // }, - // { - // seqs: []string{"aceg!", "bdfh!"}, - // indices: []int{0, 0}, - // result: "abcdefg!", - // startingPoint: "a", - // }, - // { - // // wraparound: - // // "ac"+"bdefgh" - // // abcdefgh ==> - // // defghabc - // // starting point is d. - // // Each sequence must either start after the starting point, or - // // all of its elements are considered to be below the starting - // // point. "ac" is considered to be wrapped around initially - // seqs: []string{"ac", "bdefgh"}, - // indices: []int{0, 1}, - // result: "defghabc", - // startingPoint: "d", - // }, - // { - // seqs: []string{"bc", "ae"}, - // indices: []int{0, 1}, - // result: "eabc", - // startingPoint: "d", - // }, - // { - // seqs: []string{"ac", "bfg", "deh"}, - // indices: []int{0, 0, 0}, - // result: "abcdefgh", - // startingPoint: "a", - // }, + { + seqs: []string{"abcd"}, + indices: []int{0}, + result: "abcd", + startingPoint: "a", + }, + { + seqs: []string{"abcd"}, + indices: []int{0}, + result: "abcd", + startingPoint: "c", + }, + { + seqs: []string{"abcd"}, + indices: []int{2}, + result: "cdab", + startingPoint: "c", + }, + { + seqs: []string{"abcd$"}, + indices: []int{0}, + result: "abcd$", + startingPoint: "a", + }, + { + seqs: []string{"abcd!"}, + indices: []int{0}, + result: "abcd!", + startingPoint: "a", + }, + { + seqs: []string{"abcd", "efgh"}, + indices: []int{0, 0}, + result: "abcdefgh", + startingPoint: "a", + }, + { + seqs: []string{"aceg", "bdfh"}, + indices: []int{0, 0}, + result: "abcdefgh", + startingPoint: "a", + }, + { + seqs: []string{"abcd$", "efgh$"}, + indices: []int{0, 0}, + result: "abcdefgh$", + startingPoint: "a", + }, + { + seqs: []string{"aceg$", "bdfh$"}, + indices: []int{0, 0}, + result: "abcdefgh$", + startingPoint: "a", + }, + { + seqs: []string{"abcd!", "efgh!"}, + indices: []int{0, 0}, + result: "abcd!", + startingPoint: "a", + }, + { + seqs: []string{"aceg!", "bdfh!"}, + 
indices: []int{0, 0}, + result: "abcdefg!", + startingPoint: "a", + }, + { + // wraparound: + // "ac"+"bdefgh" + // abcdefgh ==> + // defghabc + // starting point is d. + // Each sequence must either start after the starting point, or + // all of its elements are considered to be below the starting + // point. "ac" is considered to be wrapped around initially + seqs: []string{"ac", "bdefgh"}, + indices: []int{0, 1}, + result: "defghabc", + startingPoint: "d", + }, + { + seqs: []string{"bc", "ae"}, + indices: []int{0, 1}, + result: "eabc", + startingPoint: "d", + }, + { + seqs: []string{"ac", "bfg", "deh"}, + indices: []int{0, 0, 0}, + result: "abcdefgh", + startingPoint: "a", + }, { seqs: []string{"abdefgh", "c"}, indices: []int{0, 0}, From 4708f0716e00bc5a0b88cc609895568c45548904 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 5 Sep 2024 21:42:19 +0400 Subject: [PATCH 73/76] sync2: fix multipeer test --- sync2/multipeer/setsyncbase_test.go | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/sync2/multipeer/setsyncbase_test.go b/sync2/multipeer/setsyncbase_test.go index 4c32e45764..381cde146b 100644 --- a/sync2/multipeer/setsyncbase_test.go +++ b/sync2/multipeer/setsyncbase_test.go @@ -70,13 +70,13 @@ func (st *setSyncBaseTester) expectCopy(ctx context.Context, addedKeys ...types. return copy } -func (st *setSyncBaseTester) expectSyncStore( +func (st *setSyncBaseTester) expectSync( ctx context.Context, p p2p.Peer, ss Syncer, addedKeys ...types.KeyBytes, ) { - st.ps.EXPECT().SyncStore(ctx, p, ss, nil, nil). + st.ps.EXPECT().Sync(ctx, p, ss, nil, nil). DoAndReturn(func( ctx context.Context, p p2p.Peer, @@ -90,14 +90,19 @@ func (st *setSyncBaseTester) expectSyncStore( }) } -func (st *setSyncBaseTester) failToSyncStore( +func (st *setSyncBaseTester) failToSync( ctx context.Context, p p2p.Peer, ss Syncer, err error, ) { - st.ps.EXPECT().SyncStore(ctx, p, ss, nil, nil). 
- DoAndReturn(func(ctx context.Context, p p2p.Peer, os rangesync.OrderedSet, x, y types.KeyBytes) error { + st.ps.EXPECT().Sync(ctx, p, ss, nil, nil). + DoAndReturn(func( + ctx context.Context, + p p2p.Peer, + os rangesync.OrderedSet, + x, y types.KeyBytes, + ) error { return err }) } @@ -126,8 +131,8 @@ func TestSetSyncBase(t *testing.T) { Count: 42, Sim: 0.99, } - store := st.expectCopy(ctx) - st.ps.EXPECT().Probe(ctx, p2p.Peer("p1"), store, nil, nil).Return(expPr, nil) + set := st.expectCopy(ctx) + st.ps.EXPECT().Probe(ctx, p2p.Peer("p1"), set, nil, nil).Return(expPr, nil) pr, err := st.ssb.Probe(ctx, p2p.Peer("p1")) require.NoError(t, err) require.Equal(t, expPr, pr) @@ -145,12 +150,12 @@ func TestSetSyncBase(t *testing.T) { x := types.RandomKeyBytes(32) y := types.RandomKeyBytes(32) - st.ps.EXPECT().SyncStore(ctx, p2p.Peer("p1"), ss, x, y) + st.ps.EXPECT().Sync(ctx, p2p.Peer("p1"), ss, x, y) require.NoError(t, ss.Sync(ctx, x, y)) st.os.EXPECT().Has(gomock.Any(), addedKey) st.os.EXPECT().Add(ctx, addedKey) - st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) + st.expectSync(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.Sync(ctx, nil, nil)) close(st.getWaitCh(addedKey)) @@ -173,7 +178,7 @@ func TestSetSyncBase(t *testing.T) { st.os.EXPECT().Add(ctx, addedKey) for i := 0; i < 3; i++ { st.os.EXPECT().Has(gomock.Any(), addedKey) - st.expectSyncStore(ctx, p2p.Peer("p1"), ss, addedKey) + st.expectSync(ctx, p2p.Peer("p1"), ss, addedKey) require.NoError(t, ss.Sync(ctx, nil, nil)) } close(st.getWaitCh(addedKey)) @@ -198,7 +203,7 @@ func TestSetSyncBase(t *testing.T) { st.os.EXPECT().Has(gomock.Any(), k2) st.os.EXPECT().Add(ctx, k1) st.os.EXPECT().Add(ctx, k2) - st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) + st.expectSync(ctx, p2p.Peer("p1"), ss, k1, k2) require.NoError(t, ss.Sync(ctx, nil, nil)) close(st.getWaitCh(k1)) close(st.getWaitCh(k2)) @@ -223,7 +228,7 @@ func TestSetSyncBase(t *testing.T) { st.os.EXPECT().Has(gomock.Any(), k2) // k1 is not 
propagated to syncBase due to the handler failure st.os.EXPECT().Add(ctx, k2) - st.expectSyncStore(ctx, p2p.Peer("p1"), ss, k1, k2) + st.expectSync(ctx, p2p.Peer("p1"), ss, k1, k2) require.NoError(t, ss.Sync(ctx, nil, nil)) handlerErr := errors.New("fail") st.getWaitCh(k1) <- handlerErr @@ -234,7 +239,7 @@ func TestSetSyncBase(t *testing.T) { require.ElementsMatch(t, []types.KeyBytes{k1, k2}, handledKeys) }) - t.Run("synctree based item store", func(t *testing.T) { + t.Run("real item set", func(t *testing.T) { t.Parallel() hs := make([]types.KeyBytes, 4) for n := range hs { From 9597db289e4ae06966b4ed0e4ab16e9774c4f419 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 6 Sep 2024 00:17:29 +0400 Subject: [PATCH 74/76] sync2: add syncedness check --- sync2/multipeer/multipeer.go | 30 +++++++++++++++++ sync2/multipeer/multipeer_test.go | 2 ++ sync2/multipeer/synclist.go | 54 +++++++++++++++++++++++++++++++ sync2/multipeer/synclist_test.go | 33 +++++++++++++++++++ sync2/p2p.go | 4 +++ sync2/p2p_test.go | 9 +++++- 6 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 sync2/multipeer/synclist.go create mode 100644 sync2/multipeer/synclist_test.go diff --git a/sync2/multipeer/multipeer.go b/sync2/multipeer/multipeer.go index 9df538670b..59af74b7ec 100644 --- a/sync2/multipeer/multipeer.go +++ b/sync2/multipeer/multipeer.go @@ -78,6 +78,23 @@ func WithLogger(logger *zap.Logger) MultiPeerReconcilerOpt { } } +// WithMinFullSyncednessCount sets the minimum number of full syncs that must +// have happened within the fullSyncednessPeriod for the node to be considered +// fully synced. +func WithMinFullSyncednessCount(count int) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.minFullSyncednessCount = count + } +} + +// WithFullSyncednessPeriod sets the duration within which the minimum number +// of full syncs must have happened for the node to be considered fully synced. 
+func WithFullSyncednessPeriod(d time.Duration) MultiPeerReconcilerOpt { + return func(mpr *MultiPeerReconciler) { + mpr.fullSyncednessPeriod = d + } +} + func withClock(clock clockwork.Clock) MultiPeerReconcilerOpt { return func(mpr *MultiPeerReconciler) { mpr.clock = clock @@ -123,6 +140,9 @@ type MultiPeerReconciler struct { keyLen int maxDepth int runner syncRunner + minFullSyncednessCount int + fullSyncednessPeriod time.Duration + sl *syncList } func NewMultiPeerReconciler( @@ -146,6 +166,8 @@ func NewMultiPeerReconciler( clock: clockwork.NewRealClock(), keyLen: keyLen, maxDepth: maxDepth, + minFullSyncednessCount: 3, + fullSyncednessPeriod: 15 * time.Minute, } for _, opt := range opts { opt(mpr) @@ -153,6 +175,7 @@ func NewMultiPeerReconciler( if mpr.runner == nil { mpr.runner = &runner{mpr: mpr} } + mpr.sl = newSyncList(mpr.clock, mpr.minFullSyncednessCount, mpr.fullSyncednessPeriod) return mpr } @@ -226,6 +249,7 @@ func (mpr *MultiPeerReconciler) fullSync(ctx context.Context, syncPeers []p2p.Pe err := syncer.Sync(ctx, nil, nil) switch { case err == nil: + mpr.sl.noteSync() case errors.Is(err, context.Canceled): return err default: @@ -331,3 +355,9 @@ LOOP: cancel() return errors.Join(err, mpr.syncBase.Wait()) } + +// Synced returns true if the node is considered synced, that is, the specified +// number of syncs has happened within the specified duration of time. 
+func (mpr *MultiPeerReconciler) Synced() bool { + return mpr.sl.synced() +} diff --git a/sync2/multipeer/multipeer_test.go b/sync2/multipeer/multipeer_test.go index 235c1a9fc6..bad2095f9c 100644 --- a/sync2/multipeer/multipeer_test.go +++ b/sync2/multipeer/multipeer_test.go @@ -161,6 +161,7 @@ func TestMultiPeerSync(t *testing.T) { ctx := mt.start() mt.addPeers(10) mt.syncBase.EXPECT().Count(gomock.Any()).Return(100, nil).AnyTimes() + require.False(t, mt.reconciler.Synced()) for i := 0; i < numSyncs; i++ { mt.expectProbe(6, rangesync.ProbeResult{ FP: "foo", @@ -173,6 +174,7 @@ func TestMultiPeerSync(t *testing.T) { mt.clock.Advance(time.Minute) mt.satisfy() } + require.True(t, mt.reconciler.Synced()) mt.syncBase.EXPECT().Wait() }) diff --git a/sync2/multipeer/synclist.go b/sync2/multipeer/synclist.go new file mode 100644 index 0000000000..93effaa616 --- /dev/null +++ b/sync2/multipeer/synclist.go @@ -0,0 +1,54 @@ +package multipeer + +import ( + "container/list" + "sync" + "time" + + "github.com/jonboulle/clockwork" +) + +// syncList keeps track of recent full syncs and reports whether the node is synced, that +// is, the specified number of syncs has happened within the specified duration of time +type syncList struct { + mtx sync.Mutex + clock clockwork.Clock + minSyncCount int + duration time.Duration + syncs list.List +} + +func newSyncList(clock clockwork.Clock, minSyncCount int, duration time.Duration) *syncList { + return &syncList{ + clock: clock, + minSyncCount: minSyncCount, + duration: duration, + } +} + +func (sl *syncList) prune(now time.Time) { + t := now.Add(-sl.duration) + for sl.syncs.Len() != 0 { + el := sl.syncs.Back() + if t.After(el.Value.(time.Time)) { + sl.syncs.Remove(el) + } else { + break + } + } +} + +func (sl *syncList) noteSync() { + sl.mtx.Lock() + defer sl.mtx.Unlock() + now := sl.clock.Now() + sl.prune(now) + sl.syncs.PushFront(now) +} + +func (sl *syncList) synced() bool { + sl.mtx.Lock() + defer sl.mtx.Unlock() + 
sl.prune(sl.clock.Now()) + return sl.syncs.Len() >= sl.minSyncCount +} diff --git a/sync2/multipeer/synclist_test.go b/sync2/multipeer/synclist_test.go new file mode 100644 index 0000000000..249f9f51fc --- /dev/null +++ b/sync2/multipeer/synclist_test.go @@ -0,0 +1,33 @@ +package multipeer + +import ( + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func TestSyncList(t *testing.T) { + clk := clockwork.NewFakeClock() + sl := newSyncList(clk, 3, 5*time.Minute) + require.False(t, sl.synced()) + sl.noteSync() + require.False(t, sl.synced()) + clk.Advance(time.Minute) + sl.noteSync() + require.False(t, sl.synced()) + clk.Advance(time.Minute) + sl.noteSync() + require.True(t, sl.synced()) + clk.Advance(time.Minute) + // 3 minutes have passed + require.True(t, sl.synced()) + clk.Advance(2*time.Minute + 30*time.Second) + // 5 minutes 30 s have passed + require.False(t, sl.synced()) + sl.noteSync() + require.True(t, sl.synced()) + // make sure the list is pruned and is not growing indefinitely + require.Equal(t, 3, sl.syncs.Len()) +} diff --git a/sync2/p2p.go b/sync2/p2p.go index 633341eda8..fde939729c 100644 --- a/sync2/p2p.go +++ b/sync2/p2p.go @@ -135,3 +135,7 @@ func (s *P2PHashSync) Stop() { s.logger.Error("P2PHashSync terminated with an error", zap.Error(err)) } } + +func (s *P2PHashSync) Synced() bool { + return s.reconciler.Synced() +} diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index 7c4c4a7cae..0b6fe9f589 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -2,6 +2,7 @@ package sync2 import ( "context" + "fmt" "sync" "testing" "time" @@ -57,19 +58,25 @@ func TestP2P(t *testing.T) { return nil } os := rangesync.NewDumbHashSet(true) - hs[n] = NewP2PHashSync(logger, host, os, 32, 24, "sync2test", ps, handler, cfg) + hs[n] = NewP2PHashSync( + logger.Named(fmt.Sprintf("node%d", n)), + host, os, 32, 24, "sync2test", ps, handler, cfg) if n == 0 { is := hs[n].Set() for _, h := range initialSet { 
is.Add(context.Background(), h) } } + require.False(t, hs[n].Synced()) hs[n].Start() } require.Eventually(t, func() bool { for _, hsync := range hs { // use a snapshot to avoid races + if !hsync.Synced() { + return false + } os := hsync.Set().Copy() empty, err := os.Empty(context.Background()) require.NoError(t, err) From 7968770c1e29da47c6438c6d91f124750a3a70f8 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 6 Sep 2024 02:12:07 +0400 Subject: [PATCH 75/76] sync2: don't pre-read round messages in RangeSetReconciler --- sync2/rangesync/log.go | 3 + sync2/rangesync/p2p.go | 19 +++++- sync2/rangesync/rangesync.go | 95 +++++++++++++--------------- sync2/rangesync/rangesync_test.go | 6 +- sync2/rangesync/wire_conduit.go | 90 ++++++++++++++++++-------- sync2/rangesync/wire_conduit_test.go | 39 +++++++----- 6 files changed, 153 insertions(+), 99 deletions(-) diff --git a/sync2/rangesync/log.go b/sync2/rangesync/log.go index 25a979ac1a..46d2df170f 100644 --- a/sync2/rangesync/log.go +++ b/sync2/rangesync/log.go @@ -25,6 +25,9 @@ func (f seqFormatter) String() string { } func SeqField(name string, seq types.Seq) zap.Field { + if seq == nil { + return zap.String(name, "") + } return zap.Stringer(name, seqFormatter{seq: seq}) } diff --git a/sync2/rangesync/p2p.go b/sync2/rangesync/p2p.go index 1102a54811..043e13d140 100644 --- a/sync2/rangesync/p2p.go +++ b/sync2/rangesync/p2p.go @@ -47,8 +47,9 @@ func (pss *PairwiseSetSyncer) Probe( return ProbeResult{}, err } err = pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { - c.stream = stream var err error + c.begin(ctx, stream) + defer c.end() pr, err = rsr.HandleProbeResponse(&c, info) return err }) @@ -83,7 +84,13 @@ func (pss *PairwiseSetSyncer) Sync( return err } return pss.r.StreamRequest(ctx, peer, initReq, func(ctx context.Context, stream io.ReadWriter) error { - return c.handleStream(ctx, stream, rsr) + c.begin(ctx, stream) + defer c.end() + if err := rsr.Run(ctx, &c); 
err != nil { + c.closeStream() // stop the writer + return err + } + return nil }) } @@ -103,5 +110,11 @@ func (pss *PairwiseSetSyncer) Serve( Reader: io.MultiReader(bytes.NewBuffer(req), stream), Writer: stream, } - return c.handleStream(ctx, s, rsr) + c.begin(ctx, s) + defer c.end() + if err := rsr.Run(ctx, &c); err != nil { + c.closeStream() // stop the writer + return err + } + return nil } diff --git a/sync2/rangesync/rangesync.go b/sync2/rangesync/rangesync.go index 8e23361869..31ffa6c5d2 100644 --- a/sync2/rangesync/rangesync.go +++ b/sync2/rangesync/rangesync.go @@ -835,27 +835,6 @@ func (rsr *RangeSetReconciler) initiateBounded(ctx context.Context, c Conduit, x } } -func (rsr *RangeSetReconciler) getMessages(c Conduit) (msgs []SyncMessage, done bool, err error) { - for { - msg, err := c.NextMessage() - switch { - case err != nil: - return msgs, false, err - case msg == nil: - return msgs, false, errors.New("no end round marker") - default: - switch msg.Type() { - case MessageTypeEndRound: - return msgs, false, nil - case MessageTypeDone: - return msgs, true, nil - default: - msgs = append(msgs, msg) - } - } - } -} - func (rsr *RangeSetReconciler) InitiateProbe(ctx context.Context, c Conduit) (RangeInfo, error) { return rsr.InitiateBoundedProbe(ctx, c, nil, nil) } @@ -997,23 +976,29 @@ func (rsr *RangeSetReconciler) HandleProbeResponse(c Conduit, info RangeInfo) (p } } -func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done bool, err error) { - var msgs []SyncMessage - // All of the round's messages need to be received before processing them, as - // processing the messages involves sending more messages back to the peer. - // TODO: use proper goroutines in the wireConduit to deal with send/recv blocking. 
- msgs, done, err = rsr.getMessages(c) - if done { - // items already added - if len(msgs) != 0 { - return false, errors.New("no extra messages expected along with 'done' message") - } - return done, nil - } +var errNoEndMarker = errors.New("no end round marker") +var errEmptyRound = errors.New("empty round") + +func (rsr *RangeSetReconciler) doRound(ctx context.Context, c Conduit) (done bool, err error) { done = true var receivedKeys []types.Ordered - for _, msg := range msgs { - if msg.Type() == MessageTypeItemBatch { + nHandled := 0 +RECV_LOOP: + for { + msg, err := c.NextMessage() + switch { + case err != nil: + return false, err + case msg == nil: + return false, errNoEndMarker + } + switch msg.Type() { + case MessageTypeEndRound: + break RECV_LOOP + case MessageTypeDone: + return true, nil + case MessageTypeItemBatch: + nHandled++ for _, k := range msg.Keys() { rsr.log.Debug("Process: add item", HexField("item", k)) if err := rsr.os.Add(ctx, k); err != nil { @@ -1024,30 +1009,26 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo continue } - // If there was an error, just add any items received, - // but ignore other messages - if err != nil { - continue - } - // TODO: pass preceding range. 
Somehow, currently the code // breaks if we capture the iterator from handleMessage and // pass it to the next handleMessage call (it shouldn't) - var msgDone bool - msgDone, err = rsr.handleMessage(ctx, c, msg, receivedKeys) + msgDone, err := rsr.handleMessage(ctx, c, msg, receivedKeys) + if err != nil { + return false, err + } + nHandled++ if !msgDone { done = false } - receivedKeys = nil - } - - if err != nil { - return false, err + receivedKeys = receivedKeys[:0] } - if done { + switch { + case done: err = c.SendDone() - } else { + case nHandled == 0: + err = errEmptyRound + default: err = c.SendEndRound() } @@ -1057,6 +1038,18 @@ func (rsr *RangeSetReconciler) Process(ctx context.Context, c Conduit) (done boo return done, nil } +func (rsr *RangeSetReconciler) Run(ctx context.Context, c Conduit) error { + for { + // Process() will receive all items and messages from the peer + syncDone, err := rsr.doRound(ctx, c) + if err != nil { + return err + } else if syncDone { + return nil + } + } +} + func fingerprintEqual(a, b any) bool { // FIXME: use Fingerprint interface with Equal() method for fingerprints // but still allow nil fingerprints diff --git a/sync2/rangesync/rangesync_test.go b/sync2/rangesync/rangesync_test.go index d68c7811be..9214dfee1b 100644 --- a/sync2/rangesync/rangesync_test.go +++ b/sync2/rangesync/rangesync_test.go @@ -255,7 +255,7 @@ func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) nMsg += len(fc.msgs) nItems += fc.numItems() var err error - bDone, err = syncB.Process(context.Background(), fc) + bDone, err = syncB.doRound(context.Background(), fc) require.NoError(fc.t, err) // a party should never send anything in response to the "done" message require.False(fc.t, aDone && !bDone, "A is done but B after that is not") @@ -268,7 +268,7 @@ func doRunSync(fc *fakeConduit, syncA, syncB *RangeSetReconciler, maxRounds int) fc.gotoResponse() nMsg += len(fc.msgs) nItems += fc.numItems() - aDone, err = 
syncA.Process(context.Background(), fc) + aDone, err = syncA.doRound(context.Background(), fc) require.NoError(fc.t, err) dumpRangeMessages(fc.t, fc.msgs, "A %q --> B %q:", setStr(syncB.os), setStr(syncA.os)) dumpRangeMessages(fc.t, fc.resp, "A -> B:") @@ -298,7 +298,7 @@ func runBoundedProbe(t *testing.T, from, to *RangeSetReconciler, x, y types.Orde func doRunProbe(fc *fakeConduit, from, to *RangeSetReconciler, info RangeInfo) ProbeResult { require.NotEmpty(fc.t, fc.resp, "empty initial round") fc.gotoResponse() - done, err := to.Process(context.Background(), fc) + done, err := to.doRound(context.Background(), fc) require.True(fc.t, done) require.NoError(fc.t, err) fc.gotoResponse() diff --git a/sync2/rangesync/wire_conduit.go b/sync2/rangesync/wire_conduit.go index c6aa8c2081..bd6b0accc8 100644 --- a/sync2/rangesync/wire_conduit.go +++ b/sync2/rangesync/wire_conduit.go @@ -10,6 +10,11 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/sync2/types" + "golang.org/x/sync/errgroup" +) + +const ( + writeQueueSize = 10000 ) type sendable interface { @@ -20,10 +25,56 @@ type sendable interface { type wireConduit struct { stream io.ReadWriter initReqBuf *bytes.Buffer + eg errgroup.Group + sendCh chan sendable } var _ Conduit = &wireConduit{} +type deadline interface { + SetDeadline(time.Time) error +} + +func (c *wireConduit) closeStream() { + if closer, ok := c.stream.(io.Closer); ok { + closer.Close() + } +} + +func (c *wireConduit) begin(ctx context.Context, s io.ReadWriter) { + if c.stream != nil { + panic("BUG: wireConduit: begin() already called for this wireConduit") + } + c.stream = s + c.sendCh = make(chan sendable, writeQueueSize) + c.eg.Go(func() error { + for { + select { + case <-ctx.Done(): + c.closeStream() + return ctx.Err() + case m, ok := <-c.sendCh: + if !ok { + return nil + } + if err := writeMessage(c.stream, m); err != nil { + c.closeStream() + return err + } + } + } + }) +} + +func (c *wireConduit) end() { 
+ if c.stream == nil { + panic("BUG: wireConduit: end() called without begin()") + } + close(c.sendCh) + c.eg.Wait() + c.stream = nil +} + func (c *wireConduit) NextMessage() (SyncMessage, error) { var b [1]byte if _, err := io.ReadFull(c.stream, b[:]); err != nil { @@ -88,20 +139,11 @@ func (c *wireConduit) NextMessage() (SyncMessage, error) { } func (c *wireConduit) send(m sendable) error { - var stream io.Writer - if c.initReqBuf != nil { - stream = c.initReqBuf - } else if c.stream == nil { - panic("BUG: wireConduit: no stream") - } else { - stream = c.stream + if c.initReqBuf == nil { + c.sendCh <- m + return nil } - b := []byte{byte(m.Type())} - if _, err := stream.Write(b); err != nil { - return err - } - _, err := codec.EncodeTo(stream, m) - return err + return writeMessage(c.initReqBuf, m) } func (c *wireConduit) SendFingerprint(x, y types.Ordered, fp types.Fingerprint, count int) error { @@ -214,19 +256,15 @@ func (c *wireConduit) withInitialRequest(toCall func(Conduit) error) ([]byte, er return c.initReqBuf.Bytes(), nil } -func (c *wireConduit) handleStream(ctx context.Context, stream io.ReadWriter, rsr *RangeSetReconciler) error { - c.stream = stream - for { - // Process() will receive all items and messages from the peer - syncDone, err := rsr.Process(ctx, c) - if err != nil { - return err - } else if syncDone { - return nil - } - } -} - func (c *wireConduit) ShortenKey(k types.Ordered) types.Ordered { return MinhashSampleItemFromKeyBytes(k.(types.KeyBytes)) } + +func writeMessage(w io.Writer, m sendable) error { + b := []byte{byte(m.Type())} + if _, err := w.Write(b); err != nil { + return err + } + _, err := codec.EncodeTo(w, m) + return err +} diff --git a/sync2/rangesync/wire_conduit_test.go b/sync2/rangesync/wire_conduit_test.go index d3cb9a066d..7b707e68bc 100644 --- a/sync2/rangesync/wire_conduit_test.go +++ b/sync2/rangesync/wire_conduit_test.go @@ -3,6 +3,7 @@ package rangesync import ( "bytes" "context" + "errors" "fmt" "io" "slices" @@ -16,6 
+17,15 @@ import ( "github.com/spacemeshos/go-spacemesh/sync2/types" ) +type pipeStream struct { + io.ReadCloser + io.WriteCloser +} + +func (ps *pipeStream) Close() error { + return errors.Join(ps.ReadCloser.Close(), ps.WriteCloser.Close()) +} + type incomingRequest struct { initialRequest []byte stream io.ReadWriter @@ -71,25 +81,22 @@ func (fr *fakeRequester) request( if !found { return fmt.Errorf("bad peer %q", pid) } - r, w := io.Pipe() - defer r.Close() - defer w.Close() - stream := struct { - io.Reader - io.Writer - }{ - Reader: r, - Writer: w, + rClient, wServer := io.Pipe() + rServer, wClient := io.Pipe() + for _, s := range []io.Closer{rClient, wClient, rServer, wServer} { + defer s.Close() } + clientStream := &pipeStream{ReadCloser: rClient, WriteCloser: wClient} + serverStream := &pipeStream{ReadCloser: rServer, WriteCloser: wServer} select { case p.reqCh <- incomingRequest{ initialRequest: initialRequest, - stream: stream, + stream: serverStream, }: case <-ctx.Done(): return ctx.Err() } - return callback(ctx, stream) + return callback(ctx, clientStream) } func (fr *fakeRequester) StreamRequest( @@ -111,7 +118,7 @@ type fakeSend struct { done bool } -func (fs *fakeSend) send(c Conduit) error { +func (fs *fakeSend) send(c Conduit, t *testing.T, name string) error { // QQQQQ: rm t and name switch { case fs.endRound: return c.SendEndRound() @@ -159,7 +166,7 @@ func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { return err } for _, s := range r.toSend { - if err := s.send(c); err != nil { + if err := s.send(c, t, r.name); err != nil { return err } } @@ -185,10 +192,10 @@ func makeTestStreamHandler(t *testing.T, c *wireConduit, rounds []fakeRound) ser func makeTestRequestCallback(t *testing.T, c *wireConduit, rounds []fakeRound) server.StreamRequestCallback { return func(ctx context.Context, stream io.ReadWriter) error { if c == nil { - c = &wireConduit{stream: stream} - } else { - c.stream = stream + c = &wireConduit{} } + 
c.begin(ctx, stream) + defer c.end() for _, round := range rounds { if err := round.handleConversation(t, c); err != nil { return err From 238bbb3deaeba157d3f348a6af29b399133600dc Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 6 Sep 2024 20:52:09 +0400 Subject: [PATCH 76/76] sync2: initial syncer integration --- fetch/fetch.go | 9 +++ sync2/atxs.go | 98 ++++++++++++++++++++++++++++ sync2/dbsync/syncedtable.go | 8 +++ sync2/dbsync/syncedtable_test.go | 11 +--- sync2/p2p.go | 46 +++++++++++-- sync2/p2p_test.go | 1 + sync2/rangesync/dumbset.go | 4 ++ sync2/rangesync/interface.go | 4 ++ sync2/rangesync/wire_conduit_test.go | 4 +- syncer/syncer.go | 89 +++++++++++++++++++++++++ 10 files changed, 258 insertions(+), 16 deletions(-) create mode 100644 sync2/atxs.go diff --git a/fetch/fetch.go b/fetch/fetch.go index 94e93caea6..bd34657c87 100644 --- a/fetch/fetch.go +++ b/fetch/fetch.go @@ -10,6 +10,7 @@ import ( "sync" "time" + corehost "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -1006,3 +1007,11 @@ func (f *Fetch) SelectBestShuffled(n int) []p2p.Peer { }) return peers } + +func (f *Fetch) Host() corehost.Host { + return f.host.(corehost.Host) +} + +func (f *Fetch) Peers() *peers.Peers { + return f.peers +} diff --git a/sync2/atxs.go b/sync2/atxs.go new file mode 100644 index 0000000000..c3b1e4afb5 --- /dev/null +++ b/sync2/atxs.go @@ -0,0 +1,98 @@ +package sync2 + +import ( + "context" + + "github.com/libp2p/go-libp2p/core/host" + + smtypes "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/fetch/peers" + "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2/dbsync" + "github.com/spacemeshos/go-spacemesh/sync2/multipeer" + "github.com/spacemeshos/go-spacemesh/sync2/types" + "github.com/spacemeshos/go-spacemesh/system" + "go.uber.org/zap" +) + +const ( + 
oldAtxProto = "sync2/old-atx" + curAtxProto = "sync2/cur-atx" + oldAtxMaxDepth = 16 + curAtxMaxDepth = 24 +) + +type Fetcher interface { + system.AtxFetcher + Host() host.Host + Peers() *peers.Peers + RegisterPeerHash(peer p2p.Peer, hash smtypes.Hash32) +} + +type layerTicker interface { + CurrentLayer() smtypes.LayerID +} + +func AtxHandler(f Fetcher) multipeer.SyncKeyHandler { + return func(ctx context.Context, k types.Ordered, peer p2p.Peer) error { + var id smtypes.ATXID + copy(id[:], k.(types.KeyBytes)) + f.RegisterPeerHash(peer, id.Hash32()) + if err := f.GetAtxs(ctx, []smtypes.ATXID{id}); err != nil { + return err + } + return nil + } +} + +func atxsTable(epochFilter string, curEpoch smtypes.EpochID) *dbsync.SyncedTable { + return &dbsync.SyncedTable{ + TableName: "atxs", + IDColumn: "id", + TimestampColumn: "received", + Filter: dbsync.MustParseSQLExpr(epochFilter), + Binder: func(s *sql.Statement) { + s.BindInt64(1, int64(curEpoch)) + }, + } +} + +func newAtxSyncer( + logger *zap.Logger, + cfg Config, + db sql.StateDatabase, + f Fetcher, + ticker layerTicker, + epochFilter string, + maxDepth int, + proto string, +) *P2PHashSync { + // TODO: handle epoch switch + curEpoch := ticker.CurrentLayer().GetEpoch() + curSet := dbsync.NewDBSet(db, atxsTable(epochFilter, curEpoch), 32, curAtxMaxDepth) + return NewP2PHashSync(logger, f.Host(), curSet, 32, maxDepth, proto, f.Peers(), AtxHandler(f), cfg) +} + +func NewCurAtxSyncer( + logger *zap.Logger, + cfg Config, + db sql.StateDatabase, + f Fetcher, + ticker layerTicker, +) *P2PHashSync { + return newAtxSyncer(logger, cfg, db, f, ticker, "epoch = ?", curAtxMaxDepth, curAtxProto) +} + +func NewOldAtxSyncer( + logger *zap.Logger, + cfg Config, + db sql.StateDatabase, + f Fetcher, + ticker layerTicker, +) *P2PHashSync { + return newAtxSyncer(logger, cfg, db, f, ticker, "epoch < ?", oldAtxMaxDepth, oldAtxProto) +} + +// TODO: test +// TODO: per-round SQL transactions diff --git a/sync2/dbsync/syncedtable.go 
b/sync2/dbsync/syncedtable.go index 350e15107b..44c8370b9a 100644 --- a/sync2/dbsync/syncedtable.go +++ b/sync2/dbsync/syncedtable.go @@ -306,3 +306,11 @@ func (sts *SyncedTableSnapshot) loadRecent( dec) return err } + +func MustParseSQLExpr(s string) rsql.Expr { + expr, err := rsql.ParseExprString(s) + if err != nil { + panic("error parsing SQL expression: " + err.Error()) + } + return expr +} diff --git a/sync2/dbsync/syncedtable_test.go b/sync2/dbsync/syncedtable_test.go index 916429c4e1..67aa4b8d44 100644 --- a/sync2/dbsync/syncedtable_test.go +++ b/sync2/dbsync/syncedtable_test.go @@ -3,7 +3,6 @@ package dbsync import ( "testing" - rsql "github.com/rqlite/sql" "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/util" @@ -11,12 +10,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sync2/types" ) -func parseSQLExpr(t *testing.T, s string) rsql.Expr { - expr, err := rsql.ParseExprString(s) - require.NoError(t, err) - return expr -} - func TestSyncedTable_GenSQL(t *testing.T) { for _, tc := range []struct { name string @@ -47,7 +40,7 @@ func TestSyncedTable_GenSQL(t *testing.T) { st: SyncedTable{ TableName: "atxs", IDColumn: "id", - Filter: parseSQLExpr(t, "epoch = ?"), + Filter: MustParseSQLExpr("epoch = ?"), TimestampColumn: "received", }, allRC: `SELECT "id" FROM "atxs" WHERE "epoch" = ? 
AND "rowid" <= ?`, @@ -296,7 +289,7 @@ func TestSyncedTable_LoadIDs(t *testing.T) { TableName: "atxs", IDColumn: "id", TimestampColumn: "received", - Filter: parseSQLExpr(t, "epoch = ?"), + Filter: MustParseSQLExpr("epoch = ?"), Binder: func(stmt *sql.Statement) { stmt.BindInt64(1, 2) }, diff --git a/sync2/p2p.go b/sync2/p2p.go index fde939729c..4ff17069ef 100644 --- a/sync2/p2p.go +++ b/sync2/p2p.go @@ -30,6 +30,8 @@ type Config struct { MinSplitSyncPeers int `mapstructure:"min-split-sync-peers"` MinCompleteFraction float64 `mapstructure:"min-complete-fraction"` SplitSyncGracePeriod time.Duration `mapstructure:"split-sync-grace-period"` + RecentTimeSpan time.Duration `mapstructure:"recent-time-span"` + EnableActiveSync bool `mapstructure:"enable-active-sync"` } func DefaultConfig() Config { @@ -50,6 +52,7 @@ func DefaultConfig() Config { type P2PHashSync struct { logger *zap.Logger + cfg Config h host.Host os rangesync.OrderedSet syncBase multipeer.SyncBase @@ -75,14 +78,19 @@ func NewP2PHashSync( logger: logger, h: h, os: os, + cfg: cfg, } s.srv = server.New(h, proto, s.handle, server.WithTimeout(cfg.Timeout), server.WithLog(logger)) - ps := rangesync.NewPairwiseSetSyncer(s.srv, []rangesync.RangeSetReconcilerOption{ + rangeSyncOpts := []rangesync.RangeSetReconcilerOption{ rangesync.WithMaxSendRange(cfg.MaxSendRange), rangesync.WithSampleSize(cfg.SampleSize), - }) + } + if cfg.RecentTimeSpan > 0 { + rangeSyncOpts = append(rangeSyncOpts, rangesync.WithRecentTimeSpan(cfg.RecentTimeSpan)) + } + ps := rangesync.NewPairwiseSetSyncer(s.srv, rangeSyncOpts) s.syncBase = multipeer.NewSetSyncBase(ps, s.os, handler) s.reconciler = multipeer.NewMultiPeerReconciler( s.syncBase, peers, keyLen, maxDepth, @@ -120,14 +128,26 @@ func (s *P2PHashSync) Start() { s.start.Do(func() { var ctx context.Context ctx, s.cancel = context.WithCancel(context.Background()) - s.eg.Go(func() error { return s.srv.Run(ctx) }) - s.eg.Go(func() error { return s.reconciler.Run(ctx) }) + s.eg.Go(func() 
error { + s.logger.Info("loading the set") + // We pre-load the set to avoid waiting for it to load during a + // sync request + if err := s.os.EnsureLoaded(ctx); err != nil { + return err + } + if s.cfg.EnableActiveSync { + s.eg.Go(func() error { return s.reconciler.Run(ctx) }) + } + return s.srv.Run(ctx) + }) s.running.Store(true) }) } func (s *P2PHashSync) Stop() { - s.running.Store(false) + if !s.running.CompareAndSwap(true, false) { + return + } if s.cancel != nil { s.cancel() } @@ -139,3 +159,19 @@ func (s *P2PHashSync) Stop() { func (s *P2PHashSync) Synced() bool { return s.reconciler.Synced() } + +var errStopped = errors.New("atx syncer stopped") + +func (s *P2PHashSync) WaitForSync(ctx context.Context) error { + for !s.Synced() { + if !s.running.Load() { + return errStopped + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + } + } + return nil +} diff --git a/sync2/p2p_test.go b/sync2/p2p_test.go index 0b6fe9f589..472c20c622 100644 --- a/sync2/p2p_test.go +++ b/sync2/p2p_test.go @@ -44,6 +44,7 @@ func TestP2P(t *testing.T) { } } cfg := DefaultConfig() + cfg.EnableActiveSync = true cfg.SyncInterval = 100 * time.Millisecond host := mesh.Hosts()[n] handler := func(ctx context.Context, k types.Ordered, peer p2p.Peer) error { diff --git a/sync2/rangesync/dumbset.go b/sync2/rangesync/dumbset.go index 19b150d2f0..e343a0d1b3 100644 --- a/sync2/rangesync/dumbset.go +++ b/sync2/rangesync/dumbset.go @@ -91,6 +91,10 @@ type dumbSet struct { var _ OrderedSet = &dumbSet{} +func (ds *dumbSet) EnsureLoaded(ctx context.Context) error { + return nil +} + func (ds *dumbSet) Add(ctx context.Context, k types.Ordered) error { id := k.(types.KeyBytes) if len(ds.keys) == 0 { diff --git a/sync2/rangesync/interface.go b/sync2/rangesync/interface.go index 8ada2d35d2..3f104d2987 100644 --- a/sync2/rangesync/interface.go +++ b/sync2/rangesync/interface.go @@ -32,6 +32,10 @@ type SplitInfo struct { // OrderedSet represents the set that 
can be synced against a remote peer type OrderedSet interface { + // EnsureLoaded ensures that the set is loaded and ready for use. + // It may do nothing in case of in-memory sets, but may trigger loading + // from disk or database in case of on-disk or remote sets. + EnsureLoaded(ctx context.Context) error // Add adds a key to the set Add(ctx context.Context, k types.Ordered) error // GetRangeInfo returns RangeInfo for the item range in the tree. diff --git a/sync2/rangesync/wire_conduit_test.go b/sync2/rangesync/wire_conduit_test.go index 7b707e68bc..a03619abd4 100644 --- a/sync2/rangesync/wire_conduit_test.go +++ b/sync2/rangesync/wire_conduit_test.go @@ -118,7 +118,7 @@ type fakeSend struct { done bool } -func (fs *fakeSend) send(c Conduit, t *testing.T, name string) error { // QQQQQ: rm t and name +func (fs *fakeSend) send(c Conduit) error { switch { case fs.endRound: return c.SendEndRound() @@ -166,7 +166,7 @@ func (r *fakeRound) handleConversation(t *testing.T, c *wireConduit) error { return err } for _, s := range r.toSend { - if err := s.send(c, t, r.name); err != nil { + if err := s.send(c); err != nil { return err } } diff --git a/syncer/syncer.go b/syncer/syncer.go index 71332b0da3..d87392ec9b 100644 --- a/syncer/syncer.go +++ b/syncer/syncer.go @@ -18,6 +18,8 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/mesh" "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sync2" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" "github.com/spacemeshos/go-spacemesh/syncer/malsync" "github.com/spacemeshos/go-spacemesh/system" @@ -39,6 +41,9 @@ type Config struct { OutOfSyncThresholdLayers uint32 `mapstructure:"out-of-sync-threshold"` AtxSync atxsync.Config `mapstructure:"atx-sync"` MalSync malsync.Config `mapstructure:"malfeasance-sync"` + EnableSyncV2 bool `mapstructure:"enable-sync-v2"` + OldAtxSyncCfg sync2.Config 
`mapstructure:"old-atx-sync"` + CurAtxSyncCfg sync2.Config `mapstructure:"cur-atx-sync"` } // DefaultConfig for the syncer. @@ -162,6 +167,9 @@ type Syncer struct { eg errgroup.Group stop context.CancelFunc + + curAtxSyncV2 *sync2.P2PHashSync + oldAtxSyncV2 *sync2.P2PHashSync } // NewSyncer creates a new Syncer instance. @@ -207,6 +215,22 @@ func NewSyncer( s.isBusy.Store(false) s.lastLayerSynced.Store(s.mesh.LatestLayer().Uint32()) s.lastEpochSynced.Store(types.GetEffectiveGenesis().GetEpoch().Uint32() - 1) + if s.cfg.EnableSyncV2 { + s.curAtxSyncV2 = sync2.NewCurAtxSyncer( + s.logger.Named("cur-atx-sync"), + s.cfg.CurAtxSyncCfg, + cdb.Database.(sql.StateDatabase), + fetcher.(sync2.Fetcher), + ticker, + ) + s.oldAtxSyncV2 = sync2.NewOldAtxSyncer( + s.logger.Named("old-atx-sync"), + s.cfg.CurAtxSyncCfg, + cdb.Database.(sql.StateDatabase), + fetcher.(sync2.Fetcher), + ticker, + ) + } return s } @@ -471,10 +495,74 @@ func (s *Syncer) synchronize(ctx context.Context) bool { return success } +func (s *Syncer) syncAtxV2(ctx context.Context) error { + currentLayer := s.ticker.CurrentLayer() + publish := currentLayer.GetEpoch() + if publish == 0 { + return nil // nothing to sync in epoch 0 + } + if !s.ListenToATXGossip() { + // TODO: syncv2 + s.logger.Debug("syncing atx from genesis", + log.ZContext(ctx), + zap.Stringer("current layer", currentLayer), + zap.Stringer("last epoch", s.lastAtxEpoch()), + ) + s.oldAtxSyncV2.Start() + if err := s.oldAtxSyncV2.WaitForSync(ctx); err != nil { + return fmt.Errorf("error syncing old ATXs: %w", err) + } + s.logger.Debug("atxs synced to epoch", + log.ZContext(ctx), zap.Stringer("last epoch", s.lastAtxEpoch())) + + // TODO: use syncv2 for malfeasance proofs too + s.logger.Info("syncing malicious proofs", log.ZContext(ctx)) + if err := s.syncMalfeasance(ctx, currentLayer.GetEpoch()); err != nil { + return err + } + s.logger.Info("malicious IDs synced", log.ZContext(ctx)) + s.setATXSynced() + } + + // TODO: advance upon epoch change + // 
if current.OrdinalInEpoch() <= uint32(float64(types.GetLayersPerEpoch())*s.cfg.EpochEndFraction) { + // // not advancing yet ... + // } + s.curAtxSyncV2.Start() + if !s.malSync.started { + s.malSync.started = true + s.malSync.eg.Go(func() error { + select { + case <-ctx.Done(): + return nil + case <-s.awaitATXSyncedCh: + err := s.malsyncer.DownloadLoop(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + s.logger.Error("malfeasance sync failed", log.ZContext(ctx), zap.Error(err)) + } + return nil + } + }) + } + + return nil +} + func (s *Syncer) syncAtx(ctx context.Context) error { + if s.cfg.EnableSyncV2 && s.cfg.OldAtxSyncCfg.EnableActiveSync && s.cfg.CurAtxSyncCfg.EnableActiveSync { + return s.syncAtxV2(ctx) + } else if s.cfg.EnableSyncV2 { + if s.cfg.OldAtxSyncCfg.EnableActiveSync || s.cfg.CurAtxSyncCfg.EnableActiveSync { + return errors.New("should enable both old & new atx syncv2 or disable both") + } + // start server-only syncers + s.curAtxSyncV2.Start() + s.oldAtxSyncV2.Start() + } current := s.ticker.CurrentLayer() // on startup always download all activations that were published before current epoch if !s.ListenToATXGossip() { + // TODO: syncv2 s.logger.Debug("syncing atx from genesis", log.ZContext(ctx), zap.Stringer("current layer", current), @@ -511,6 +599,7 @@ func (s *Syncer) syncAtx(ctx context.Context) error { s.backgroundSync.epoch.Store(0) } if s.backgroundSync.epoch.Load() == 0 && publish.Uint32() != 0 { + // TODO: syncv2 s.logger.Debug("download atx for epoch in background", zap.Stringer("publish", publish), log.ZContext(ctx)) s.backgroundSync.epoch.Store(publish.Uint32()) ctx, cancel := context.WithCancel(ctx)