diff --git a/pkg/adt/interval_tree.go b/pkg/adt/interval_tree.go index 2b13623cc3a..cf898f93e12 100644 --- a/pkg/adt/interval_tree.go +++ b/pkg/adt/interval_tree.go @@ -87,39 +87,39 @@ type intervalNode struct { c rbcolor } -func (x *intervalNode) color() rbcolor { - if x == nil { +func (x *intervalNode) color(nilNode *intervalNode) rbcolor { + if x == nilNode { return black } return x.c } -func (x *intervalNode) height() int { - if x == nil { +func (x *intervalNode) height(nilNode *intervalNode) int { + if x == nilNode { return 0 } - ld := x.left.height() - rd := x.right.height() + ld := x.left.height(nilNode) + rd := x.right.height(nilNode) if ld < rd { return rd + 1 } return ld + 1 } -func (x *intervalNode) min() *intervalNode { - for x.left != nil { +func (x *intervalNode) min(nilNode *intervalNode) *intervalNode { + for x.left != nilNode { x = x.left } return x } // successor is the next in-order node in the tree -func (x *intervalNode) successor() *intervalNode { - if x.right != nil { - return x.right.min() +func (x *intervalNode) successor(nilNode *intervalNode) *intervalNode { + if x.right != nilNode { + return x.right.min(nilNode) } y := x.parent - for y != nil && x == y.right { + for y != nilNode && x == y.right { x = y y = y.parent } @@ -127,14 +127,14 @@ func (x *intervalNode) successor() *intervalNode { } // updateMax updates the maximum values for a node and its ancestors -func (x *intervalNode) updateMax() { - for x != nil { +func (x *intervalNode) updateMax(nilNode *intervalNode) { + for x != nilNode { oldmax := x.max max := x.iv.Ivl.End - if x.left != nil && x.left.max.Compare(max) > 0 { + if x.left != nilNode && x.left.max.Compare(max) > 0 { max = x.left.max } - if x.right != nil && x.right.max.Compare(max) > 0 { + if x.right != nilNode && x.right.max.Compare(max) > 0 { max = x.right.max } if oldmax.Compare(max) == 0 { @@ -148,25 +148,25 @@ func (x *intervalNode) updateMax() { type nodeVisitor func(n *intervalNode) bool // visit will call a node visitor on each node that overlaps the given interval -func (x *intervalNode) visit(iv *Interval, nv nodeVisitor) bool { - if x == nil { +func (x *intervalNode) visit(iv *Interval, nilNode *intervalNode, nv nodeVisitor) bool { + if x == nilNode { return true } v := iv.Compare(&x.iv.Ivl) switch { case v < 0: - if !x.left.visit(iv, nv) { + if !x.left.visit(iv, nilNode, nv) { return false } case v > 0: maxiv := Interval{x.iv.Ivl.Begin, x.max} if maxiv.Compare(iv) == 0 { - if !x.left.visit(iv, nv) || !x.right.visit(iv, nv) { + if !x.left.visit(iv, nilNode, nv) || !x.right.visit(iv, nilNode, nv) { return false } } default: - if !x.left.visit(iv, nv) || !nv(x) || !x.right.visit(iv, nv) { + if !x.left.visit(iv, nilNode, nv) || !nv(x) || !x.right.visit(iv, nilNode, nv) { return false } } @@ -211,19 +211,26 @@ type IntervalTree interface { // NewIntervalTree returns a new interval tree. func NewIntervalTree() IntervalTree { - return &intervalTree{ - root: nil, - count: 0, - } + tree := &intervalTree{} + + tree.nilNode = &intervalNode{} + + tree.root = tree.nilNode + tree.nilNode.c = black + + return tree } type intervalTree struct { root *intervalNode count int - // TODO: use 'sentinel' as a dummy object to simplify boundary conditions + // use 'sentinel' as a dummy object to simplify boundary conditions // use the sentinel to treat a nil child of a node x as an ordinary node whose parent is x // use one shared sentinel to represent all nil leaves and the root's parent + + // red-black NIL node, can not use golang nil instead + nilNode *intervalNode } // TODO: make this consistent with textbook implementation @@ -263,24 +270,25 @@ type intervalTree struct { // true if a node is in fact removed. func (ivt *intervalTree) Delete(ivl Interval) bool { z := ivt.find(ivl) - if z == nil { + if z == ivt.nilNode { return false } y := z - if z.left != nil && z.right != nil { - y = z.successor() + if z.left != ivt.nilNode && z.right != ivt.nilNode { + y = z.successor(ivt.nilNode) } - x := y.left - if x == nil { + x := ivt.nilNode + if y.left != ivt.nilNode { + x = y.left + } else if y.right != ivt.nilNode { x = y.right } - if x != nil { - x.parent = y.parent - } - if y.parent == nil { + x.parent = y.parent + + if y.parent == ivt.nilNode { ivt.root = x } else { if y == y.parent.left { @@ -288,14 +296,14 @@ func (ivt *intervalTree) Delete(ivl Interval) bool { } else { y.parent.right = x } - y.parent.updateMax() + y.parent.updateMax(ivt.nilNode) } if y != z { z.iv = y.iv - z.updateMax() + z.updateMax(ivt.nilNode) } - if y.color() == black && x != nil { + if y.color(ivt.nilNode) == black { ivt.deleteFixup(x) } @@ -348,59 +356,51 @@ func (ivt *intervalTree) Delete(ivl Interval) bool { // 40. x.color = BLACK // func (ivt *intervalTree) deleteFixup(x *intervalNode) { - for x != ivt.root && x.color() == black && x.parent != nil { + for x != ivt.root && x.color(ivt.nilNode) == black { if x == x.parent.left { // line 3-20 w := x.parent.right - if w.color() == red { + if w.color(ivt.nilNode) == red { w.c = black x.parent.c = red ivt.rotateLeft(x.parent) w = x.parent.right } - if w == nil { - break - } - if w.left.color() == black && w.right.color() == black { + if w.left.color(ivt.nilNode) == black && w.right.color(ivt.nilNode) == black { w.c = red x = x.parent } else { - if w.right.color() == black { + if w.right.color(ivt.nilNode) == black { w.left.c = black w.c = red ivt.rotateRight(w) w = x.parent.right } - w.c = x.parent.color() + w.c = x.parent.color(ivt.nilNode) x.parent.c = black w.right.c = black ivt.rotateLeft(x.parent) x = ivt.root } - } else { // line 22-38 - // same as above but with left and right exchanged w := x.parent.left - if w.color() == red { + if w.color(ivt.nilNode) == red { w.c = black x.parent.c = red ivt.rotateRight(x.parent) w = x.parent.left } - if w == nil { - break - } - if w.left.color() == black && w.right.color() == black { + if w.left.color(ivt.nilNode) == black && w.right.color(ivt.nilNode) == black { w.c = red x = x.parent } else { - if w.left.color() == black { + if w.left.color(ivt.nilNode) == black { w.right.c = black w.c = red ivt.rotateLeft(w) w = x.parent.left } - w.c = x.parent.color() + w.c = x.parent.color(ivt.nilNode) x.parent.c = black w.left.c = black ivt.rotateRight(x.parent) @@ -408,10 +408,7 @@ func (ivt *intervalTree) deleteFixup(x *intervalNode) { } } } - - if x != nil { - x.c = black - } + x.c = black } func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *intervalNode { @@ -419,9 +416,9 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte iv: IntervalValue{ivl, val}, max: ivl.End, c: red, - left: nil, - right: nil, - parent: nil, + left: ivt.nilNode, + right: ivt.nilNode, + parent: ivt.nilNode, } } @@ -458,10 +455,10 @@ func (ivt *intervalTree) createIntervalNode(ivl Interval, val interface{}) *inte // Insert adds a node with the given interval into the tree. func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { - var y *intervalNode + y := ivt.nilNode z := ivt.createIntervalNode(ivl, val) x := ivt.root - for x != nil { + for x != ivt.nilNode { y = x if z.iv.Ivl.Begin.Compare(x.iv.Ivl.Begin) < 0 { x = x.left @@ -471,7 +468,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { } z.parent = y - if y == nil { + if y == ivt.nilNode { ivt.root = z } else { if z.iv.Ivl.Begin.Compare(y.iv.Ivl.Begin) < 0 { @@ -479,7 +476,7 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { } else { y.right = z } - y.updateMax() + y.updateMax(ivt.nilNode) } z.c = red @@ -520,12 +517,13 @@ func (ivt *intervalTree) Insert(ivl Interval, val interface{}) { // 28. LEFT-ROTATE(T, z.p.p) // 29. // 30. T.root.color = BLACK -// + func (ivt *intervalTree) insertFixup(z *intervalNode) { - for z.parent != nil && z.parent.parent != nil && z.parent.color() == red { + for z.parent.color(ivt.nilNode) == red { if z.parent == z.parent.parent.left { // line 3-15 + y := z.parent.parent.right - if y.color() == red { + if y.color(ivt.nilNode) == red { y.c = black z.parent.c = black z.parent.parent.c = red @@ -542,7 +540,7 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) { } else { // line 16-28 // same as then with left/right exchanged y := z.parent.parent.left - if y.color() == red { + if y.color(ivt.nilNode) == red { y.c = black z.parent.c = black z.parent.parent.c = red @@ -586,25 +584,25 @@ func (ivt *intervalTree) insertFixup(z *intervalNode) { // 16. // 17. y.left = x // 18. x.p = y -// func (ivt *intervalTree) rotateLeft(x *intervalNode) { - // line 2-3 + // rotateLeft x must have right child + if x.right == ivt.nilNode { + return + } + y := x.right x.right = y.left - - // line 5-6 - if y.left != nil { + if y.left != ivt.nilNode { y.left.parent = x } - - x.updateMax() + x.updateMax(ivt.nilNode) // line 10-15, 18 ivt.replaceParent(x, y) // line 17 y.left = x - y.updateMax() + y.updateMax(ivt.nilNode) } // rotateRight moves x so it is right of its left child @@ -628,35 +626,31 @@ func (ivt *intervalTree) rotateLeft(x *intervalNode) { // 16. // 17. y.right = x // 18. x.p = y -// func (ivt *intervalTree) rotateRight(x *intervalNode) { - if x == nil { + // rotateRight x must have left child + if x.left == ivt.nilNode { return } - // line 2-3 y := x.left x.left = y.right - - // line 5-6 - if y.right != nil { + if y.right != ivt.nilNode { y.right.parent = x } - - x.updateMax() + x.updateMax(ivt.nilNode) // line 10-15, 18 ivt.replaceParent(x, y) // line 17 y.right = x - y.updateMax() + y.updateMax(ivt.nilNode) } // replaceParent replaces x's parent with y func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { y.parent = x.parent - if x.parent == nil { + if x.parent == ivt.nilNode { ivt.root = y } else { if x == x.parent.left { @@ -664,7 +658,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { } else { x.parent.right = y } - x.parent.updateMax() + x.parent.updateMax(ivt.nilNode) } x.parent = y } @@ -673,7 +667,7 @@ func (ivt *intervalTree) replaceParent(x *intervalNode, y *intervalNode) { func (ivt *intervalTree) Len() int { return ivt.count } // Height is the number of levels in the tree; one node has height 1. -func (ivt *intervalTree) Height() int { return ivt.root.height() } +func (ivt *intervalTree) Height() int { return ivt.root.height(ivt.nilNode) } // MaxHeight is the expected maximum tree height given the number of nodes func (ivt *intervalTree) MaxHeight() int { @@ -686,11 +680,12 @@ type IntervalVisitor func(n *IntervalValue) bool // Visit calls a visitor function on every tree node intersecting the given interval. // It will visit each interval [x, y) in ascending order sorted on x. func (ivt *intervalTree) Visit(ivl Interval, ivv IntervalVisitor) { - ivt.root.visit(&ivl, func(n *intervalNode) bool { return ivv(&n.iv) }) + ivt.root.visit(&ivl, ivt.nilNode, func(n *intervalNode) bool { return ivv(&n.iv) }) } // find the exact node for a given interval -func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) { +func (ivt *intervalTree) find(ivl Interval) *intervalNode { + ret := ivt.nilNode f := func(n *intervalNode) bool { if n.iv.Ivl != ivl { return true @@ -698,14 +693,16 @@ func (ivt *intervalTree) find(ivl Interval) (ret *intervalNode) { ret = n return false } - ivt.root.visit(&ivl, f) + + ivt.root.visit(&ivl, ivt.nilNode, f) + return ret } // Find gets the IntervalValue for the node matching the given interval func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) { n := ivt.find(ivl) - if n == nil { + if n == ivt.nilNode { return nil } return &n.iv @@ -714,14 +711,14 @@ func (ivt *intervalTree) Find(ivl Interval) (ret *IntervalValue) { // Intersects returns true if there is some tree node intersecting the given interval. func (ivt *intervalTree) Intersects(iv Interval) bool { x := ivt.root - for x != nil && iv.Compare(&x.iv.Ivl) != 0 { - if x.left != nil && x.left.max.Compare(iv.Begin) > 0 { + for x != ivt.nilNode && iv.Compare(&x.iv.Ivl) != 0 { + if x.left != ivt.nilNode && x.left.max.Compare(iv.Begin) > 0 { x = x.left } else { x = x.right } } - return x != nil + return x != ivt.nilNode } // Contains returns true if the interval tree's keys cover the entire given interval. @@ -789,7 +786,7 @@ func (vi visitedInterval) String() string { // visitLevel traverses tree in level order. // used for testing func (ivt *intervalTree) visitLevel() []visitedInterval { - if ivt.root == nil { + if ivt.root == ivt.nilNode { return nil } @@ -804,22 +801,22 @@ func (ivt *intervalTree) visitLevel() []visitedInterval { f := queue[0] queue = queue[1:] - ivt := visitedInterval{ + ivt2 := visitedInterval{ root: f.node.iv.Ivl, - color: f.node.color(), + color: f.node.color(ivt.nilNode), depth: f.depth, } - if f.node.left != nil { - ivt.left = f.node.left.iv.Ivl + if f.node.left != ivt.nilNode { + ivt2.left = f.node.left.iv.Ivl queue = append(queue, pair{f.node.left, f.depth + 1}) } - if f.node.right != nil { - ivt.right = f.node.right.iv.Ivl + if f.node.right != ivt.nilNode { + ivt2.right = f.node.right.iv.Ivl queue = append(queue, pair{f.node.right, f.depth + 1}) } - rs = append(rs, ivt) + rs = append(rs, ivt2) } return rs diff --git a/pkg/adt/interval_tree_test.go b/pkg/adt/interval_tree_test.go index 00b60670a04..069166e8fa5 100644 --- a/pkg/adt/interval_tree_test.go +++ b/pkg/adt/interval_tree_test.go @@ -298,7 +298,7 @@ func TestIntervalTreeDelete(t *testing.T) { // / \ / // [238,239] [292,293] [953,954] // - t.Logf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11) + t.Fatalf("level order after deleting '11' expected %v, got %v", expectedAfterDelete11, visitsAfterDelete11) } }