From 036bd1ab09e389280a8c5dc3120825a1d61069e6 Mon Sep 17 00:00:00 2001 From: xkey Date: Mon, 5 Aug 2019 11:40:18 +0800 Subject: [PATCH] pkg/adt: fix interval tree black-height property based on rbtree Author: xkey ref. https://github.com/etcd-io/etcd/pull/10978 Signed-off-by: Gyuho Lee --- pkg/adt/interval_tree.go | 185 ++++++++++++++++++---------------- pkg/adt/interval_tree_test.go | 2 +- 2 files changed, 101 insertions(+), 86 deletions(-) diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go index 2b13623cc3a..2e5b2ddb882 100644 --- a/pkg/adt/interval_tree.go +++ b/pkg/adt/interval_tree.go @@ -87,39 +87,39 @@ type intervalNode struct { c rbcolor } -func (x *intervalNode) color() rbcolor { - if x == nil { +func (x *intervalNode) color(sentinel *intervalNode) rbcolor { + if x == sentinel { return black } return x.c } -func (x *intervalNode) height() int { - if x == nil { +func (x *intervalNode) height(sentinel *intervalNode) int { + if x == sentinel { return 0 } - ld := x.left.height() - rd := x.right.height() + ld := x.left.height(sentinel) + rd := x.right.height(sentinel) if ld < rd { return rd + 1 } return ld + 1 } -func (x *intervalNode) min() *intervalNode { - for x.left != nil { +func (x *intervalNode) min(sentinel *intervalNode) *intervalNode { + for x.left != sentinel { x = x.left } return x } // successor is the next in-order node in the tree -func (x *intervalNode) successor() *intervalNode { - if x.right != nil { - return x.right.min() +func (x *intervalNode) successor(sentinel *intervalNode) *intervalNode { + if x.right != sentinel { + return x.right.min(sentinel) } y := x.parent - for y != nil && x == y.right { + for y != sentinel && x == y.right { x = y y = y.parent } @@ -127,14 +127,14 @@ func (x *intervalNode) successor() *intervalNode { } // updateMax updates the maximum values for a node and its ancestors -func (x *intervalNode) updateMax() { - for x != nil { +func (x *intervalNode) updateMax(sentinel *intervalNode) { + for x != sentinel { oldmax := x.max max := x.iv.Ivl.End - if x.left != nil && x.left.max.Compare(max) > 0 { + if x.left != sentinel && x.left.max.Compare(max) > 0 { max = x.left.max } - if x.right != nil && x.right.max.Compare(max) > 0 { + if x.right != sentinel && x.right.max.Compare(max) > 0 { max = x.right.max } if oldmax.Compare(max) == 0 { @@ -148,25 +148,25 @@ func (x *intervalNode) updateMax() { type nodeVisitor func(n *intervalNode) bool // visit will call a node visitor on each node that overlaps the given interval -func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool { - if x == nil { +func (x *intervalNode) visit(iv *Interval, sentinel *intervalNode, nv nodeVisitor) bool { + if x == sentinel { return true } v := iv.Compare(&x.iv.Ivl) switch { case v < 0: - if !x.left.visit(iv, nv) { + if !x.left.visit(iv, sentinel, nv) { return false } case v > 0: maxiv := Interval{x.iv.Ivl.Begin, x.max} if maxiv.Compare(iv) == 0 { - if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) { + if !x.left.visit(iv, sentinel, nv) || !x.right.visit(iv, sentinel, nv) { return false } } default: - if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) { + if !x.left.visit(iv, sentinel, nv) || !nv(x) || !x.right.visit(iv, sentinel, nv) { return false } } @@ -211,9 +211,18 @@ type IntervalTree interface { // NewIntervalTree returns a new interval tree. func NewIntervalTree() IntervalTree { + sentinel := &intervalNode{ + iv: IntervalValue{}, + max: nil, + left: nil, + right: nil, + parent: nil, + c: black, + } return &intervalTree{ - root: nil, - count: 0, + root: sentinel, + count: 0, + sentinel: sentinel, } } @@ -221,9 +230,11 @@ type intervalTree struct { root *intervalNode count int - // TODO: use 'sentinel' as a dummy object to simplify boundary conditions + // red-black NIL node + // use 'sentinel' as a dummy object to simplify boundary conditions // use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x // use one shared sentinel to represent all nil leaves and the root's parent + sentinel *intervalNode } // TODO: make this consistent with textbook implementation @@ -263,24 +274,25 @@ type intervalTree struct { // true if a node is in fact removed. func (ivt *intervalTree) Delete(ivl Interval) bool { z := ivt.find(ivl) - if z == nil { + if z == ivt.sentinel { return false } y := z - if z.left != nil && z.right != nil { - y = z.successor() + if z.left != ivt.sentinel && z.right != ivt.sentinel { + y = z.successor(ivt.sentinel) } - x := y.left - if x == nil { + x := ivt.sentinel + if y.left != ivt.sentinel { + x = y.left + } else if y.right != ivt.sentinel { x = y.right } - if x != nil { - x.parent = y.parent - } - if y.parent == nil { + x.parent = y.parent + + if y.parent == ivt.sentinel { ivt.root = x } else { if y == y.parent.left { @@ -288,14 +300,14 @@ func (ivt *intervalTree) Delete(ivl Interval) bool { } else { y.parent.right = x } - y.parent.updateMax() + y.parent.updateMax(ivt.sentinel) } if y != z { z.iv = y.iv - z.updateMax() + z.updateMax(ivt.sentinel) } - if y.color() == black && x != nil { + if y.color(ivt.sentinel) == black { ivt.deleteFixup(x) } @@ -348,10 +360,10 @@ func (ivt *intervalTree) Delete(ivl Interval) bool { // 40. x.color = BLACK // func (ivt *intervalTree) deleteFixup(x *intervalNode) { - for x != ivt.root && x.color() == black && x.parent != nil { + for x != ivt.root && x.color(ivt.sentinel) == black { if x == x.parent.left { // line 3-20 w := x.parent.right - if w.color() == red { + if w.color(ivt.sentinel) == red { w.c = black x.parent.c = red ivt.rotateLeft(x.parent) @@ -360,28 +372,26 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) { if w == nil { break } - if w.left.color() == black && w.right.color() == black { + if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black { w.c = red x = x.parent } else { - if w.right.color() == black { + if w.right.color(ivt.sentinel) == black { w.left.c = black w.c = red ivt.rotateRight(w) w = x.parent.right } - w.c = x.parent.color() + w.c = x.parent.color(ivt.sentinel) x.parent.c = black w.right.c = black ivt.rotateLeft(x.parent) x = ivt.root } - } else { // line 22-38 - // same as above but with left and right exchanged w := x.parent.left - if w.color() == red { + if w.color(ivt.sentinel) == red { w.c = black x.parent.c = red ivt.rotateRight(x.parent) @@ -390,17 +400,17 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) { if w == nil { break } - if w.left.color() == black && w.right.color() == black { + if w.left.color(ivt.sentinel) == black && w.right.color(ivt.sentinel) == black { w.c = red x = x.parent } else { - if w.left.color() == black { + if w.left.color(ivt.sentinel) == black { w.right.c = black w.c = red ivt.rotateLeft(w) w = x.parent.left } - w.c = x.parent.color() + w.c = x.parent.color(ivt.sentinel) x.parent.c = black w.left.c = black ivt.rotateRight(x.parent) @@ -419,9 +429,9 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte iv: IntervalValue{ivl, val}, max: ivl.End, c: red, - left: nil, - right: nil, - parent: nil, + left: ivt.sentinel, + right: ivt.sentinel, + parent: ivt.sentinel, } } @@ -458,10 +468,10 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte // Insert adds a node with the given interval into the tree. func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { - var y *intervalNode + y := ivt.sentinel z := ivt.createIntervalNode(ivl, val) x := ivt.root - for x != nil { + for x != ivt.sentinel { y = x if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 { x = x.left @@ -471,7 +481,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { } z.parent = y - if y == nil { + if y == ivt.sentinel { ivt.root = z } else { if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 { @@ -479,7 +489,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { } else { y.right = z } - y.updateMax() + y.updateMax(ivt.sentinel) } z.c = red @@ -522,10 +532,11 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { // 30. T.root.color = BLACK // func (ivt *intervalTree) insertFixup(z *intervalNode) { - for z.parent != nil && z.parent.parent != nil && z.parent.color() == red { + for z.parent.color(ivt.sentinel) == red { if z.parent == z.parent.parent.left { // line 3-15 + y := z.parent.parent.right - if y.color() == red { + if y.color(ivt.sentinel) == red { y.c = black z.parent.c = black z.parent.parent.c = red @@ -542,7 +553,7 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) { } else { // line 16-28 // same as then with left/right exchanged y := z.parent.parent.left - if y.color() == red { + if y.color(ivt.sentinel) == red { y.c = black z.parent.c = black z.parent.parent.c = red @@ -588,23 +599,27 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) { // 18. x.p = y // func (ivt *intervalTree) rotateLeft(x *intervalNode) { + // rotateLeft x must have right child + if x.right == ivt.sentinel { + return + } + // line 2-3 y := x.right x.right = y.left // line 5-6 - if y.left != nil { + if y.left != ivt.sentinel { y.left.parent = x } - - x.updateMax() + x.updateMax(ivt.sentinel) // line 10-15, 18 ivt.replaceParent(x, y) // line 17 y.left = x - y.updateMax() + y.updateMax(ivt.sentinel) } // rotateRight moves x so it is right of its left child @@ -630,7 +645,8 @@ func (ivt *intervalTree) rotateLeft(x *intervalNode) { // 18. x.p = y // func (ivt *intervalTree) rotateRight(x *intervalNode) { - if x == nil { + // rotateRight x must have left child + if x.left == ivt.sentinel { return } @@ -639,24 +655,23 @@ func (ivt *intervalTree) rotateRight(x *intervalNode) { x.left = y.right // line 5-6 - if y.right != nil { + if y.right != ivt.sentinel { y.right.parent = x } - - x.updateMax() + x.updateMax(ivt.sentinel) // line 10-15, 18 ivt.replaceParent(x, y) // line 17 y.right = x - y.updateMax() + y.updateMax(ivt.sentinel) } // replaceParent replaces x's parent with y func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { y.parent = x.parent - if x.parent == nil { + if x.parent == ivt.sentinel { ivt.root = y } else { if x == x.parent.left { @@ -664,7 +679,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { } else { x.parent.right = y } - x.parent.updateMax() + x.parent.updateMax(ivt.sentinel) } x.parent = y } @@ -673,7 +688,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { func (ivt *intervalTree) Len() int { return ivt.count } // Height is the number of levels in the tree; one node has height 1. -func (ivt *intervalTree) Height() int { return ivt.root.height() } +func (ivt *intervalTree) Height() int { return ivt.root.height(ivt.sentinel) } // MaxHeight is the expected maximum tree height given the number of nodes func (ivt *intervalTree) MaxHeight() int { @@ -686,11 +701,12 @@ type IntervalVisitor func(n *IntervalValue) bool // Visit calls a visitor function on every tree node intersecting the given interval. // It will visit each interval [x, y) in ascending order sorted on x. func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) { - ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) }) + ivt.root.visit(&ivl, ivt.sentinel, func(n *intervalNode) bool { return ivv(&n.iv) }) } // find the exact node for a given interval -func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) { +func (ivt *intervalTree) find(ivl Interval) *intervalNode { + ret := ivt.sentinel f := func(n *intervalNode) bool { if n.iv.Ivl != ivl { return true @@ -698,14 +714,14 @@ func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) { ret = n return false } - ivt.root.visit(&ivl, f) + ivt.root.visit(&ivl, ivt.sentinel, f) return ret } // Find gets the IntervalValue for the node matching the given interval func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) { n := ivt.find(ivl) - if n == nil { + if n == ivt.sentinel { return nil } return &n.iv @@ -714,14 +730,14 @@ func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) { // Intersects returns true if there is some tree node intersecting the given interval. func (ivt *intervalTree) Intersects(iv Interval) bool { x := ivt.root - for x != nil && iv.Compare(&x.iv.Ivl) != 0 { - if x.left != nil && x.left.max.Compare(iv.Begin) > 0 { + for x != ivt.sentinel && iv.Compare(&x.iv.Ivl) != 0 { + if x.left != ivt.sentinel && x.left.max.Compare(iv.Begin) > 0 { x = x.left } else { x = x.right } } - return x != nil + return x != ivt.sentinel } // Contains returns true if the interval tree's keys cover the entire given interval. @@ -789,7 +805,7 @@ func (vi visitedInterval) String() string { // visitLevel traverses tree in level order. // used for testing func (ivt *intervalTree) visitLevel() []visitedInterval { - if ivt.root == nil { + if ivt.root == ivt.sentinel { return nil } @@ -804,22 +820,21 @@ func (ivt *intervalTree) visitLevel() []visitedInterval { f := queue[0] queue = queue[1:] - ivt := visitedInterval{ + vi := visitedInterval{ root: f.node.iv.Ivl, - color: f.node.color(), + color: f.node.color(ivt.sentinel), depth: f.depth, } - - if f.node.left != nil { - ivt.left = f.node.left.iv.Ivl + if f.node.left != ivt.sentinel { + vi.left = f.node.left.iv.Ivl queue = append(queue, pair{f.node.left, f.depth + 1}) } - if f.node.right != nil { - ivt.right = f.node.right.iv.Ivl + if f.node.right != ivt.sentinel { + vi.right = f.node.right.iv.Ivl queue = append(queue, pair{f.node.right, f.depth + 1}) } - rs = append(rs, ivt) + rs = append(rs, vi) } return rs diff --git a/pkg/adt/interval_tree_test.go b/pkg/adt/interval_tree_test.go index 00b60670a04..069166e8fa5 100644 --- a/pkg/adt/interval_tree_test.go +++ b/pkg/adt/interval_tree_test.go @@ -298,7 +298,7 @@ func TestIntervalTreeDelete(t *testing.T) { // / \ / // [238,239] [292,293] [953,954] // - t.Logf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11) + t.Fatalf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11) } }