diff --git a/queue.go b/queue.go index 57a6fd060..4df0434e0 100644 --- a/queue.go +++ b/queue.go @@ -268,24 +268,6 @@ func (q *TransmitLimitedQueue) addItem(cur *limitedBroadcast) { } } -// getTransmitRange returns a pair of min/max values for transmit values -// represented by the current queue contents. Both values represent actual -// transmit values on the interval [0, len). You must already hold the mutex. -func (q *TransmitLimitedQueue) getTransmitRange() (minTransmit, maxTransmit int) { - if q.lenLocked() == 0 { - return 0, 0 - } - minItem, maxItem := q.tq.Min(), q.tq.Max() - if minItem == nil || maxItem == nil { - return 0, 0 - } - - min := minItem.(*limitedBroadcast).transmits - max := maxItem.(*limitedBroadcast).transmits - - return min, max -} - // GetBroadcasts is used to get a number of broadcasts, up to a byte limit // and applying a per-message overhead as provided. func (q *TransmitLimitedQueue) GetBroadcasts(overhead, limit int) [][]byte { @@ -302,76 +284,40 @@ func (q *TransmitLimitedQueue) GetBroadcasts(overhead, limit int) [][]byte { var ( bytesUsed int toSend [][]byte - reinsert []*limitedBroadcast + picked []*limitedBroadcast ) - - // Visit fresher items first, but only look at stuff that will fit. - // We'll go tier by tier, grabbing the largest items first. - minTr, maxTr := q.getTransmitRange() - for transmits := minTr; transmits <= maxTr; /*do not advance automatically*/ { - free := int64(limit - bytesUsed - overhead) + var lb *limitedBroadcast + var free int64 + q.tq.Ascend(func(item btree.Item) bool { + lb = item.(*limitedBroadcast) + free = int64(limit - bytesUsed - overhead) if free <= 0 { - break // bail out early - } - - // Search for the least element on a given tier (by transmit count) as - // defined in the limitedBroadcast.Less function that will fit into our - // remaining space. - greaterOrEqual := &limitedBroadcast{ - transmits: transmits, - msgLen: free, - id: math.MaxInt64, - } - lessThan := &limitedBroadcast{ - transmits: transmits + 1, - msgLen: math.MaxInt64, - id: math.MaxInt64, + return false // bail out early } - var keep *limitedBroadcast - q.tq.AscendRange(greaterOrEqual, lessThan, func(item btree.Item) bool { - cur := item.(*limitedBroadcast) - // Check if this is within our limits - if int64(len(cur.b.Message())) > free { - // If this happens it's a bug in the datastructure or - // surrounding use doing something like having len(Message()) - // change over time. There's enough going on here that it's - // probably sane to just skip it and move on for now. - return true - } - keep = cur - return false - }) - if keep == nil { - // No more items of an appropriate size in the tier. - transmits++ - continue + if int64(len(lb.b.Message())) > free { + return true // continue to next message } - - msg := keep.b.Message() - - // Add to slice to send - bytesUsed += overhead + len(msg) + // msg ok to broadcast + msg := lb.b.Message() toSend = append(toSend, msg) + bytesUsed += overhead + len(msg) + picked = append(picked, lb) + return true + }) - // Check if we should stop transmission - q.deleteItem(keep) - if keep.transmits+1 >= transmitLimit { - keep.b.Finished() + // delete the picked message from queue. + // check the transmitted times + // to decide whether to finish or to continue transmission + for _, lb := range picked { + q.deleteItem(lb) + lb.transmits++ + if lb.transmits >= transmitLimit { + lb.b.Finished() } else { - // We need to bump this item down to another transmit tier, but - // because it would be in the same direction that we're walking the - // tiers, we will have to delay the reinsertion until we are - // finished our search. Otherwise we'll possibly re-add the message - // when we ascend to the next tier. - keep.transmits++ - reinsert = append(reinsert, keep) + q.addItem(lb) } } - for _, cur := range reinsert { - q.addItem(cur) - } - return toSend } diff --git a/queue_test.go b/queue_test.go index b49f647a1..7bc2ebb2f 100644 --- a/queue_test.go +++ b/queue_test.go @@ -4,6 +4,7 @@ package memberlist import ( + "strings" "testing" "github.com/google/btree" @@ -114,39 +115,65 @@ func TestTransmitLimited_GetBroadcasts(t *testing.T) { } func TestTransmitLimited_GetBroadcasts_Limit(t *testing.T) { - q := &TransmitLimitedQueue{RetransmitMult: 1, NumNodes: func() int { return 10 }} + q := &TransmitLimitedQueue{RetransmitMult: 1, NumNodes: func() int { return 100 }} require.Equal(t, int64(0), q.idGen, "the id generator seed starts at zero") - require.Equal(t, 2, retransmitLimit(q.RetransmitMult, q.NumNodes()), "sanity check transmit limits") + require.Equal(t, 3, retransmitLimit(q.RetransmitMult, q.NumNodes()), "sanity check transmit limits") // 18 bytes per message - q.QueueBroadcast(&memberlistBroadcast{"test", []byte("1. this is a test."), nil}) - q.QueueBroadcast(&memberlistBroadcast{"foo", []byte("2. this is a test."), nil}) - q.QueueBroadcast(&memberlistBroadcast{"bar", []byte("3. this is a test."), nil}) - q.QueueBroadcast(&memberlistBroadcast{"baz", []byte("4. this is a test."), nil}) + q.queueBroadcast(&memberlistBroadcast{"test", []byte("1. this is a test."), nil}, 1) + q.queueBroadcast(&memberlistBroadcast{"foo", []byte("2. this is a test."), nil}, 1) + // 54 bytes per message + q.queueBroadcast(&memberlistBroadcast{"bar", []byte(strings.Repeat("3. this is a test.", 3)), nil}, 0) + q.queueBroadcast(&memberlistBroadcast{"baz", []byte(strings.Repeat("4. this is a test.", 3)), nil}, 1) require.Equal(t, int64(4), q.idGen, "we handed out 4 IDs") + dump := q.orderedView(false) + if dump[0].b.(*memberlistBroadcast).node != "bar" { + t.Fatalf("missing bar") + } + if dump[1].b.(*memberlistBroadcast).node != "baz" { + t.Fatalf("missing baz") + } + if dump[2].b.(*memberlistBroadcast).node != "foo" { + t.Fatalf("missing foo") + } + if dump[3].b.(*memberlistBroadcast).node != "test" { + t.Fatalf("missing test") + } + // 3 byte overhead, should only get 3 messages back - partial1 := q.GetBroadcasts(3, 80) + partial1 := q.GetBroadcasts(3, 99) require.Equal(t, 3, len(partial1), "missing messages: %v", prettyPrintMessages(partial1)) require.Equal(t, int64(4), q.idGen, "id generator doesn't reset until empty") - partial2 := q.GetBroadcasts(3, 80) + partial2 := q.GetBroadcasts(3, 99) require.Equal(t, 3, len(partial2), "missing messages: %v", prettyPrintMessages(partial2)) require.Equal(t, int64(4), q.idGen, "id generator doesn't reset until empty") // Only two not expired - partial3 := q.GetBroadcasts(3, 80) - require.Equal(t, 2, len(partial3), "missing messages: %v", prettyPrintMessages(partial3)) + partial3 := q.GetBroadcasts(3, 99) + require.Equal(t, 1, len(partial3), "missing messages: %v", prettyPrintMessages(partial3)) + + require.Equal(t, int64(4), q.idGen, "id generator doesn't reset until empty") + // Only two not expired + partial4 := q.GetBroadcasts(3, 99) + require.Equal(t, 1, len(partial4), "missing messages: %v", prettyPrintMessages(partial3)) + + require.Equal(t, int64(4), q.idGen, "id generator doesn't reset until empty") + + // Only one not expired + partial5 := q.GetBroadcasts(3, 99) + require.Equal(t, 1, len(partial5), "missing messages: %v", prettyPrintMessages(partial5)) require.Equal(t, int64(0), q.idGen, "id generator resets on empty") // Should get nothing - partial5 := q.GetBroadcasts(3, 80) - require.Equal(t, 0, len(partial5), "missing messages: %v", prettyPrintMessages(partial5)) + partial6 := q.GetBroadcasts(3, 99) + require.Equal(t, 0, len(partial6), "missing messages: %v", prettyPrintMessages(partial6)) require.Equal(t, int64(0), q.idGen, "id generator resets on empty") }