From 57e780f3e5b90738a3fbf009f4045ffd806c9d71 Mon Sep 17 00:00:00 2001 From: Derek Collison Date: Sat, 18 May 2024 10:09:53 -0700 Subject: [PATCH] Make sure to not run away when acking from the middle of pending. Also optimize loop for interest based streams to ack msgs based on actual floor of pending. Signed-off-by: Derek Collison --- server/consumer.go | 32 ++++++++++++++----- server/filestore.go | 13 ++++++-- server/jetstream_test.go | 67 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 100 insertions(+), 12 deletions(-) diff --git a/server/consumer.go b/server/consumer.go index 2b6148bb5f..34cac17007 100644 --- a/server/consumer.go +++ b/server/consumer.go @@ -2748,7 +2748,7 @@ func (o *consumer) processAckMsg(sseq, dseq, dc uint64, reply string, doSample b return } - var sagap uint64 + var sgap, floor uint64 var needSignal bool switch o.cfg.AckPolicy { @@ -2792,12 +2792,29 @@ func (o *consumer) processAckMsg(sseq, dseq, dc uint64, reply string, doSample b if o.maxp > 0 && len(o.pending) >= o.maxp { needSignal = true } - sagap = sseq - o.asflr + sgap = sseq - o.asflr + floor = sgap // start at same and set lower as we go. o.adflr, o.asflr = dseq, sseq - for seq := sseq; seq > sseq-sagap && len(o.pending) > 0; seq-- { + + remove := func(seq uint64) { delete(o.pending, seq) delete(o.rdc, seq) o.removeFromRedeliverQueue(seq) + if seq < floor { + floor = seq + } + } + // Determine if smarter to walk all of pending vs the sequence range. + if sgap > uint64(len(o.pending)) { + for seq := range o.pending { + if seq <= sseq { + remove(seq) + } + } + } else { + for seq := sseq; seq > sseq-sgap && len(o.pending) > 0; seq-- { + remove(seq) + } } case AckNone: // FIXME(dlc) - This is error but do we care? @@ -2808,20 +2825,19 @@ func (o *consumer) processAckMsg(sseq, dseq, dc uint64, reply string, doSample b // Update underlying store. 
o.updateAcks(dseq, sseq, reply) - clustered := o.node != nil - // In case retention changes for a stream, this ought to have been updated // using the consumer lock to avoid a race. retention := o.retention + clustered := o.node != nil o.mu.Unlock() // Let the owning stream know if we are interest or workqueue retention based. // If this consumer is clustered this will be handled by processReplicatedAck // after the ack has propagated. if !clustered && mset != nil && retention != LimitsPolicy { - if sagap > 1 { - // FIXME(dlc) - This is very inefficient, will need to fix. - for seq := sseq; seq > sseq-sagap; seq-- { + if sgap > 1 { + // FIXME(dlc) - This can be very inefficient, will need to fix. + for seq := sseq; seq >= floor; seq-- { mset.ackMsg(o, seq) } } else { diff --git a/server/filestore.go b/server/filestore.go index 75d521fba9..9bceef3520 100644 --- a/server/filestore.go +++ b/server/filestore.go @@ -8616,9 +8616,16 @@ func (o *consumerFileStore) UpdateAcks(dseq, sseq uint64) error { sgap := sseq - o.state.AckFloor.Stream o.state.AckFloor.Consumer = dseq o.state.AckFloor.Stream = sseq - for seq := sseq; seq > sseq-sgap && len(o.state.Pending) > 0; seq-- { - delete(o.state.Pending, seq) - if len(o.state.Redelivered) > 0 { + if sgap > uint64(len(o.state.Pending)) { + for seq := range o.state.Pending { + if seq <= sseq { + delete(o.state.Pending, seq) + delete(o.state.Redelivered, seq) + } + } + } else { + for seq := sseq; seq > sseq-sgap && len(o.state.Pending) > 0; seq-- { + delete(o.state.Pending, seq) delete(o.state.Redelivered, seq) } } diff --git a/server/jetstream_test.go b/server/jetstream_test.go index 76f8223c3b..dc4535bcb6 100644 --- a/server/jetstream_test.go +++ b/server/jetstream_test.go @@ -22523,7 +22523,7 @@ func TestJetStreamAckAllWithLargeFirstSequenceAndNoAckFloor(t *testing.T) { js.Publish("foo.bar", []byte("hello")) } - ss, err := js.PullSubscribe("foo.*", "TEST", nats.AckAll()) + ss, err := js.PullSubscribe("foo.*", "C1", nats.AckAll()) 
require_NoError(t, err) msgs, err := ss.Fetch(10, nats.MaxWait(100*time.Millisecond)) require_NoError(t, err) @@ -22542,4 +22542,69 @@ func TestJetStreamAckAllWithLargeFirstSequenceAndNoAckFloor(t *testing.T) { _, err = js.StreamInfo("TEST", nats.MaxWait(250*time.Millisecond)) require_NoError(t, err) + + // Now make sure that if we ack in the middle, meaning we still have ack pending, + // that we do the right thing as well. + ss, err = js.PullSubscribe("foo.*", "C2", nats.AckAll()) + require_NoError(t, err) + msgs, err = ss.Fetch(10, nats.MaxWait(100*time.Millisecond)) + require_NoError(t, err) + require_Equal(t, len(msgs), 10) + + start = time.Now() + msgs[5].AckSync() + if elapsed := time.Since(start); elapsed > 250*time.Millisecond { + t.Fatalf("AckSync took too long %v", elapsed) + } + + // Make sure next fetch works right away with low timeout. + msgs, err = ss.Fetch(10, nats.MaxWait(100*time.Millisecond)) + require_NoError(t, err) + require_Equal(t, len(msgs), 10) + + _, err = js.StreamInfo("TEST", nats.MaxWait(250*time.Millisecond)) + require_NoError(t, err) +} + +func TestJetStreamAckAllWithLargeFirstSequenceAndNoAckFloorWithInterestPolicy(t *testing.T) { + s := RunBasicJetStreamServer(t) + defer s.Shutdown() + + // Client for API requests. + nc, js := jsClientConnect(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo.>"}, + Retention: nats.InterestPolicy, + }) + require_NoError(t, err) + + // Set first sequence to something very big here. This shows the issue with AckAll the + // first time it is called and existing ack floor is 0. 
+ err = js.PurgeStream("TEST", &nats.StreamPurgeRequest{Sequence: 10_000_000_000}) + require_NoError(t, err) + + ss, err := js.PullSubscribe("foo.*", "C1", nats.AckAll()) + require_NoError(t, err) + + // Now add in 100 msgs + for i := 0; i < 100; i++ { + js.Publish("foo.bar", []byte("hello")) + } + + msgs, err := ss.Fetch(10, nats.MaxWait(100*time.Millisecond)) + require_NoError(t, err) + require_Equal(t, len(msgs), 10) + + start := time.Now() + msgs[5].AckSync() + if elapsed := time.Since(start); elapsed > 250*time.Millisecond { + t.Fatalf("AckSync took too long %v", elapsed) + } + + // We are testing for run away loops acking messages in the stream that are not there. + _, err = js.StreamInfo("TEST", nats.MaxWait(100*time.Millisecond)) + require_NoError(t, err) }