UsageTracker: add methods to create & load snapshot
Still to be wired to Kafka/ObjectStorage.

Signed-off-by: Oleg Zaytsev <mail@olegzaytsev.com>
colega committed Dec 4, 2024
1 parent 28b21a9 commit 5a769aa
Showing 3 changed files with 257 additions and 6 deletions.
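
For orientation before the diffs: the new snapshot() method writes the following byte layout, one snapshot per shard. This summary is editorial, reconstructed from the encoding calls shown below; it is not part of the commit itself.

// Snapshot wire format, version 1, one snapshot per shard:
//
//   byte     snapshotEncodingVersion (currently 1)
//   byte     shard index (0..shards-1)
//   be64     snapshot timestamp, Unix seconds
//   uvarint  number of tenants in the shard
//   for each tenant:
//     uvarint-prefixed string  tenantID
//     uvarint                  number of series
//     for each series:
//       be64  series reference (uint64)
//       byte  last-seen timestamp in minutes (the internal `minutes` type)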
16 changes: 10 additions & 6 deletions pkg/usagetracker/tracker_store.go
@@ -17,6 +17,8 @@ import (
"go.uber.org/atomic"
)

+ const snapshotEncodingVersion = 1

const shards = 128

// trackerStore holds the core business logic of the usage-tracker abstracted in a testable way.
@@ -82,7 +84,7 @@ type tenantInfo struct {
// trackSeries is used in tests so we can provide custom time.Now() value.
// trackSeries will modify and reuse the input series slice.
func (t *trackerStore) trackSeries(ctx context.Context, tenantID string, series []uint64, timeNow time.Time) (rejected []uint64, err error) {
- info := t.getOrCreateUpdatedTenantInfo(tenantID)
+ info := t.getOrCreateTenantInfo(tenantID)
limit := t.limiter.localSeriesLimit(tenantID)
if limit == 0 {
limit = math.MaxUint64
@@ -101,7 +103,7 @@ func (t *trackerStore) trackSeries(ctx context.Context, tenantID string, series
i0 := 0
for i := 1; i <= len(series); i++ {
// Track series if shard changes on the next element or if we're at the end of series.
- if currentShard := series[i0] % shards; i == len(series) || currentShard != series[i]%shards {
+ if currentShard := uint8(series[i0] % shards); i == len(series) || currentShard != uint8(series[i]%shards) {
shard := t.getOrCreateTenantShard(tenantID, currentShard, limit)
created, rejected = shard.trackSeries(series[i0:i], now, &info.series, limit, created, rejected)
i0 = i
@@ -117,7 +119,7 @@ func (t *trackerStore) trackSeries(ctx context.Context, tenantID string, series
return rejected, nil
}

- func (t *trackerStore) getOrCreateTenantShard(userID string, shard uint64, limit uint64) *tenantShard {
+ func (t *trackerStore) getOrCreateTenantShard(userID string, shard uint8, limit uint64) *tenantShard {
t.lock[shard].RLock()
m := t.data[shard][userID]
if m != nil {
@@ -140,7 +142,7 @@ func (t *trackerStore) getOrCreateTenantShard(userID string, shard uint64, limit
return m
}

- func (t *trackerStore) getOrCreateUpdatedTenantInfo(tenantID string) *tenantInfo {
+ func (t *trackerStore) getOrCreateTenantInfo(tenantID string) *tenantInfo {
t.tenantsMtx.RLock()
if info, ok := t.tenants[tenantID]; ok {
// It is important to mark it as not marked for deletion, before we release the read lock.
@@ -163,11 +165,13 @@ func (t *trackerStore) getOrCreateUpdatedTenantInfo(tenantID string) *tenantInfo

// casIfGreater will CAS ts to now if now is greater than the value stored in ts.
// It will retry until it's possible to CAS or the condition is not met.
- func casIfGreater(now minutes, ts *atomic.Int32) {
- lastSeen := minutes(ts.Load())
+ // It returns the last seen value that was stored in ts before CAS-ing.
+ func casIfGreater(now minutes, ts *atomic.Int32) (lastSeen minutes) {
+ lastSeen = minutes(ts.Load())
for now.greaterThan(lastSeen) && !ts.CompareAndSwap(int32(lastSeen), int32(now)) {
lastSeen = minutes(ts.Load())
}
+ return lastSeen
}

func (t *trackerStore) cleanup(now time.Time) {
173 changes: 173 additions & 0 deletions pkg/usagetracker/tracker_store_snapshot.go
@@ -0,0 +1,173 @@
// SPDX-License-Identifier: AGPL-3.0-only

package usagetracker

import (
"fmt"
"maps"
"time"

"github.com/prometheus/prometheus/tsdb/encoding"
"go.uber.org/atomic"
)

func (t *trackerStore) snapshot(shard uint8, now time.Time, buf []byte) []byte {
t.lock[shard].RLock()
shardTenants := maps.Clone(t.data[shard])
t.lock[shard].RUnlock()

snapshot := encoding.Encbuf{B: buf[:0]}
snapshot.PutByte(snapshotEncodingVersion)
snapshot.PutByte(shard)
snapshot.PutBE64(uint64(now.Unix()))
snapshot.PutUvarint64(uint64(len(shardTenants)))
for tenantID, shard := range shardTenants {
shard.RLock()
shardClone := maps.Clone(shard.series)
shard.RUnlock()
snapshot.PutUvarintStr(tenantID)
snapshot.PutUvarint64(uint64(len(shardClone)))
for s, ts := range shardClone {
snapshot.PutBE64(s)
snapshot.PutByte(byte(ts.Load()))
}
}
return snapshot.Get()
}

func (t *trackerStore) loadSnapshot(data []byte, now time.Time) error {
snapshot := encoding.Decbuf{B: data}
version := snapshot.Byte()
if err := snapshot.Err(); err != nil {
return fmt.Errorf("invalid snapshot format, expected version: %w", err)
}
if version != snapshotEncodingVersion {
return fmt.Errorf("unexpected snapshot version %d", version)
}
shard := snapshot.Byte()
if err := snapshot.Err(); err != nil {
return fmt.Errorf("invalid snapshot format, shard expected: %w", err)
}
if shard >= shards {
return fmt.Errorf("invalid snapshot format, shard %d out of bounds", shard)
}

snapshotTime := time.Unix(int64(snapshot.Be64()), 0)
if err := snapshot.Err(); err != nil {
return fmt.Errorf("invalid snapshot format, time expected: %w", err)
}
if snapshotAge := now.Sub(snapshotTime); snapshotAge > time.Hour {
return fmt.Errorf("snapshot is too old, snapshot time is %s (%d ago)", snapshotTime, snapshotAge)
}

tenantsLen := snapshot.Uvarint64()
if err := snapshot.Err(); err != nil {
return fmt.Errorf("invalid snapshot format, expected tenants len: %w", err)
}

t.lock[shard].RLock()
localShardTenantsClone := maps.Clone(t.data[shard])
t.lock[shard].RUnlock()

// We won't be holding the mutex on tenants series of each shard while checking timestamps.
// If we find a too old lastSeen, we might try to update it on our copy of the pointer to atomic.Uint64,
// but since we're not holding the mutex, it might be evicted at the same time.
// We fix that by requiring a mutex (at least read mutex) for updating lastSeen on values that are beyond 3/4 expiration.
mutexWatermark := toMinutes(now.Add(time.Duration(-3. / 4. * float64(t.idleTimeout))))

// Some series might have been right on the boundary of being evicted when we took the snapshot.
// Don't load them.
expirationWatermark := toMinutes(now.Add(-t.idleTimeout))
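// Editorial example (not part of the commit), assuming the 20-minute idleTimeout used in the test below:
// expirationWatermark = now - 20m, so snapshot entries last seen more than 20 minutes ago are dropped,
// while mutexWatermark = now - 15m (3/4 of the timeout), so local entries last seen between 15 and 20
// minutes ago are re-applied under the shard lock further down, since cleanup may evict them concurrently.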

for i := 0; i < int(tenantsLen); i++ {
// We don't check for tenantID string length here, because we don't require it to be non-empty when we track series.
tenantID := snapshot.UvarintStr()
if err := snapshot.Err(); err != nil {
return fmt.Errorf("failed to read tenant ID %d: %w", i, err)
}
localTenant := localShardTenantsClone[tenantID]
if localTenant == nil {
// We know nothing about this tenant, maybe we need to create it
localTenant = t.getOrCreateTenantShard(tenantID, shard, t.limiter.localSeriesLimit(tenantID))
}
info := t.getOrCreateTenantInfo(tenantID)
if err := localTenant.loadSnapshot(&snapshot, &info.series, mutexWatermark, expirationWatermark); err != nil {
return fmt.Errorf("failed loading snapshot for tenant %s (%d): %w", tenantID, i, err)
}
}
return nil
}

func (shard *tenantShard) loadSnapshot(snapshot *encoding.Decbuf, totalTenantSeries *atomic.Uint64, mutexWatermark, expirationWatermark minutes) error {
type entry struct {
series uint64
ts minutes
}

seriesLen := int(snapshot.Uvarint64())
if err := snapshot.Err(); err != nil {
return fmt.Errorf("failed to read series len: %w", err)
}
shard.RLock()
seriesClone := maps.Clone(shard.series)
shard.RUnlock()

// TODO: reuse buffers here.
// We could use ugly logic to use the same slice here for both cases, but it's probably not worth it.
var newSeries []entry
var belowMutexWatermark []entry

for i := 0; i < seriesLen; i++ {
s := snapshot.Be64()
snapshotTs := minutes(snapshot.Byte())
if expirationWatermark.greaterThan(snapshotTs) {
// We're not interested in this series, it was about to be evicted.
continue
}

ts, ok := seriesClone[s]
if ok {
lastSeen := casIfGreater(snapshotTs, ts)
if mutexWatermark.greaterThan(lastSeen) {
// We've CASed the last seen timestamp, but we're getting close to this value being evicted.
// Since we're operating on a seriesClone, we might be updating an atomic value that isn't referenced by shard.series anymore.
// So, try this series again later with mutex.
belowMutexWatermark = append(belowMutexWatermark, entry{s, lastSeen})
}
continue
}
newSeries = append(newSeries, entry{s, snapshotTs})
}

// Check the series that were very close to expiration, if any.
if len(belowMutexWatermark) > 0 {
shard.RLock()
for _, e := range belowMutexWatermark {
ts, ok := shard.series[e.series]
if ok {
casIfGreater(e.ts, ts)
continue
}
// See? It didn't exist anymore.
newSeries = append(newSeries, e)
}
shard.RUnlock()
}

// Create series that didn't exist.
if len(newSeries) > 0 {
shard.Lock()
for _, e := range newSeries {
ts, ok := shard.series[e.series]
if ok {
casIfGreater(e.ts, ts)
continue
}
shard.series[e.series] = atomic.NewInt32(int32(e.ts))
// Replaying snapshot ignores limits. Series that were created elsewhere must be created here too.
totalTenantSeries.Inc()
}
shard.Unlock()
}
return nil
}
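
As a quick illustration of how the two new methods compose, here is a minimal editorial sketch of a per-shard round trip between two stores. The helper name replicateTrackerState is hypothetical and assumes it lives in the same package; it mirrors TestTrackerStore_snapshot below, and the actual transport (Kafka/ObjectStorage) is still to be wired, per the commit message.

// replicateTrackerState copies the state of src into dst, one shard snapshot at a time.
// Hypothetical helper for illustration only; not part of this commit.
func replicateTrackerState(src, dst *trackerStore, now time.Time) error {
    var buf []byte
    for shard := uint8(0); shard < shards; shard++ {
        buf = src.snapshot(shard, now, buf[:0]) // reuse the buffer across shards
        if err := dst.loadSnapshot(buf, now); err != nil {
            return fmt.Errorf("loading snapshot for shard %d: %w", shard, err)
        }
    }
    return nil
}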
74 changes: 74 additions & 0 deletions pkg/usagetracker/tracker_store_test.go
@@ -68,6 +68,80 @@ func TestTrackerStore_HappyCase(t *testing.T) {
}
}

func TestTrackerStore_snapshot(t *testing.T) {
const defaultIdleTimeout = 20 * time.Minute
const testUser1 = "user1"
const testUser2 = "user2"
now := time.Date(2020, 1, 1, 1, 2, 3, 0, time.UTC)

tracker := newTrackerStore(defaultIdleTimeout, log.NewNopLogger(), limiterMock{}, noopEvents{})
for i := 0; i < 60; i++ {
rejected, err := tracker.trackSeries(context.Background(), testUser1, []uint64{uint64(i)}, now)
require.Empty(t, rejected)
require.NoError(t, err)

rejected, err = tracker.trackSeries(context.Background(), testUser2, []uint64{uint64(i * 1000), uint64(i * 10000)}, now)
require.Empty(t, rejected)
require.NoError(t, err)

now = now.Add(time.Minute)
tracker.cleanup(now)
}

// testUser1 has one series for each of the last defaultIdleTimeout minutes.
require.Equal(t, int(defaultIdleTimeout.Minutes()), int(tracker.tenants[testUser1].series.Load()))
// testUser2 has two series for each of the last defaultIdleTimeout minutes.
require.Equal(t, 2*int(defaultIdleTimeout.Minutes()), int(tracker.tenants[testUser2].series.Load()))

tracker2 := newTrackerStore(defaultIdleTimeout, log.NewNopLogger(), limiterMock{}, noopEvents{})
var data []byte
for shard := uint8(0); shard < shards; shard++ {
data = tracker.snapshot(shard, now, data[:0])
err := tracker2.loadSnapshot(data, now)
require.NoError(t, err)
}

require.Equal(t, int(defaultIdleTimeout.Minutes()), int(tracker2.tenants[testUser1].series.Load()))
require.Equal(t, 2*int(defaultIdleTimeout.Minutes()), int(tracker2.tenants[testUser2].series.Load()))

// Check that they hold the same data.
for i := uint8(0); i < shards; i++ {
for tenantID, originalShard := range tracker.data[i] {
loadedShard, ok := tracker2.data[i][tenantID]
require.True(t, ok, "shard %d, tenant %s", i, tenantID)
for series, ts := range originalShard.series {
loadedTs, ok := loadedShard.series[series]
require.True(t, ok, "shard %d, tenant %s, series %d", i, tenantID, series)
require.Equal(t, ts.Load(), loadedTs.Load(), "shard %d, tenant %s, series %d", i, tenantID, series)
}
}
}

// Loading the same snapshot again should be a no-op.
for shard := uint8(0); shard < shards; shard++ {
data = tracker.snapshot(shard, now, data[:0])
err := tracker2.loadSnapshot(data, now)
require.NoError(t, err)
}

// Check that the total series counts are the same.
require.Equal(t, int(defaultIdleTimeout.Minutes()), int(tracker2.tenants[testUser1].series.Load()))
require.Equal(t, 2*int(defaultIdleTimeout.Minutes()), int(tracker2.tenants[testUser2].series.Load()))

// Check that they hold the same data.
for i := uint8(0); i < shards; i++ {
for tenantID, originalShard := range tracker.data[i] {
loadedShard, ok := tracker2.data[i][tenantID]
require.True(t, ok, "shard %d, tenant %s", i, tenantID)
for series, ts := range originalShard.series {
loadedTs, ok := loadedShard.series[series]
require.True(t, ok, "shard %d, tenant %s, series %d", i, tenantID, series)
require.Equal(t, ts.Load(), loadedTs.Load(), "shard %d, tenant %s, series %d", i, tenantID, series)
}
}
}
}

type limiterMock map[string]uint64

func (l limiterMock) localSeriesLimit(userID string) uint64 { return l[userID] }
