-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,568 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
package cachettl | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
) | ||
|
||
// Cache is a double linked list sorted by expiration time (ascending order) | ||
// the root (head) node is the node with the lowest expiration time | ||
// the tail node (end) is the node with the highest expiration time | ||
// Cleanups are done on Get() calls so if Get() is never invoked then Nodes stay in-memory. | ||
type Cache[K comparable, V any] struct { | ||
root *node[K, V] | ||
mu sync.Mutex | ||
m map[K]*node[K, V] | ||
now func() time.Time | ||
} | ||
|
||
type node[K comparable, V any] struct { | ||
key K | ||
value V | ||
prev *node[K, V] | ||
next *node[K, V] | ||
ttl time.Duration | ||
expiration time.Time | ||
} | ||
|
||
func (n *node[K, V]) remove() { | ||
n.prev.next = n.next | ||
n.next.prev = n.prev | ||
} | ||
|
||
// New returns a new Cache. | ||
func New[K comparable, V any]() *Cache[K, V] { | ||
return &Cache[K, V]{ | ||
now: time.Now, | ||
root: &node[K, V]{}, | ||
m: make(map[K]*node[K, V]), | ||
} | ||
} | ||
|
||
// Get returns the value associated with the key or nil otherwise. | ||
// Additionally, Get() will refresh the TTL and cleanup expired nodes. | ||
func (c *Cache[K, V]) Get(key K) (zero V) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
defer func() { // remove expired nodes | ||
cn := c.root.next // start from head since we're sorting by expiration with the highest expiration at the tail | ||
for cn != nil && cn != c.root { | ||
if c.now().After(cn.expiration) { | ||
cn.remove() // removes a node from the linked list (leaves the map untouched) | ||
delete(c.m, cn.key) // remove node from map too | ||
} else { // there is nothing else to clean up, no need to iterate further | ||
break | ||
} | ||
cn = cn.next | ||
} | ||
}() | ||
|
||
if n, ok := c.m[key]; ok && n.expiration.After(c.now()) { | ||
n.remove() | ||
n.expiration = c.now().Add(n.ttl) // refresh TTL | ||
c.add(n) | ||
return n.value | ||
} | ||
return zero | ||
} | ||
|
||
// Put adds or updates an element inside the Cache. | ||
// The Cache will be sorted with the node with the highest expiration at the tail. | ||
func (c *Cache[K, V]) Put(key K, value V, ttl time.Duration) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
now := c.now() | ||
|
||
n, ok := c.m[key] | ||
if !ok { | ||
n = &node[K, V]{ | ||
key: key, value: value, ttl: ttl, expiration: now.Add(ttl), | ||
} | ||
c.m[key] = n | ||
} else { | ||
n.value = value | ||
n.expiration = now.Add(ttl) | ||
} | ||
|
||
if c.root.next == nil { // first node insertion | ||
c.root.next = n | ||
c.root.prev = n | ||
n.prev = c.root | ||
n.next = c.root | ||
return | ||
} | ||
|
||
if ok { // removes a node from the linked list (leaves the map untouched) | ||
n.remove() | ||
} | ||
|
||
c.add(n) | ||
} | ||
|
||
func (c *Cache[K, V]) add(n *node[K, V]) { | ||
cn := c.root.prev // tail | ||
for cn != nil { // iterate from tail to root because we have expiring nodes towards the tail | ||
if n.expiration.After(cn.expiration) || n.expiration.Equal(cn.expiration) { | ||
// insert node after cn | ||
save := cn.next | ||
cn.next = n | ||
n.prev = cn | ||
n.next = save | ||
save.prev = n | ||
break | ||
} | ||
cn = cn.prev | ||
} | ||
} | ||
|
||
// slice is used for debugging purposes only | ||
func (c *Cache[K, V]) slice() (s []V) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
cn := c.root.next | ||
for cn != nil && cn != c.root { | ||
s = append(s, cn.value) | ||
cn = cn.next | ||
} | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
package cachettl | ||
|
||
import ( | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestCacheTTL(t *testing.T) { | ||
now := time.Now() | ||
|
||
c := New[string, string]() | ||
c.now = func() time.Time { return now } | ||
|
||
// nothing done so far, we expect the cache to be empty | ||
require.Nil(t, c.slice()) | ||
|
||
// insert the very first value | ||
c.Put("two", "222", 2) | ||
require.Equal(t, []string{"222"}, c.slice()) | ||
|
||
// insert the second value with an expiration higher than the first one | ||
c.Put("three", "333", 3) | ||
require.Equal(t, []string{"222", "333"}, c.slice()) | ||
|
||
// insert the third value with an expiration lower than all other values | ||
c.Put("one", "111", 1) | ||
require.Equal(t, []string{"111", "222", "333"}, c.slice()) | ||
|
||
// update "111" to have a higher expiration than all values | ||
c.Put("one", "111", 4) | ||
require.Equal(t, []string{"222", "333", "111"}, c.slice()) | ||
|
||
// update "333" to have a higher expiration than all values | ||
c.Put("three", "333", 5) | ||
require.Equal(t, []string{"222", "111", "333"}, c.slice()) | ||
|
||
// move time forward to expire "222" | ||
c.now = func() time.Time { return now.Add(1) } // "222" should still be there | ||
require.Empty(t, c.Get("whatever")) // trigger the cleanup | ||
require.Equal(t, []string{"222", "111", "333"}, c.slice()) | ||
|
||
c.now = func() time.Time { return now.Add(2) } // "222" should still be there | ||
require.Empty(t, c.Get("whatever")) // trigger the cleanup | ||
require.Equal(t, []string{"222", "111", "333"}, c.slice()) | ||
|
||
c.now = func() time.Time { return now.Add(3) } // "222" should be expired! | ||
require.Empty(t, c.Get("whatever")) // trigger the cleanup | ||
require.Equal(t, []string{"111", "333"}, c.slice()) | ||
|
||
// let's move a lot forward to expire everything | ||
c.now = func() time.Time { return now.Add(6) } | ||
require.Empty(t, c.Get("whatever")) // trigger the cleanup | ||
require.Nil(t, c.slice()) | ||
require.Len(t, c.m, 0) | ||
|
||
// now let's set a key, then move forward and get it directly without triggering with a different key | ||
c.now = func() time.Time { return now } | ||
c.Put("last", "999", 1) | ||
require.Equal(t, "999", c.Get("last")) | ||
require.Equal(t, []string{"999"}, c.slice()) | ||
c.now = func() time.Time { return now.Add(2) } | ||
require.Empty(t, c.Get("last")) // trigger the cleanup | ||
require.Nil(t, c.slice()) | ||
require.Len(t, c.m, 0) | ||
} | ||
|
||
func TestRefreshTTL(t *testing.T) { | ||
c := New[string, string]() | ||
|
||
// nothing done so far, we expect the cache to be empty | ||
require.Nil(t, c.slice()) | ||
|
||
c.Put("one", "111", time.Second) | ||
c.Put("two", "222", time.Second) | ||
c.Put("three", "333", time.Second) | ||
require.Equal(t, []string{"111", "222", "333"}, c.slice()) | ||
|
||
require.Equal(t, "111", c.Get("one")) | ||
require.Equal(t, []string{"222", "333", "111"}, c.slice()) | ||
|
||
require.Equal(t, "222", c.Get("two")) | ||
require.Equal(t, []string{"333", "111", "222"}, c.slice()) | ||
|
||
require.Equal(t, "333", c.Get("three")) | ||
require.Equal(t, []string{"111", "222", "333"}, c.slice()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.