From c3b29d9a0a172517c165570a4b34abf0276afbe5 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 26 Mar 2024 11:28:57 -0300 Subject: [PATCH 1/8] revamp --- backoff/mocks/mocks.go | 17 +- common/common.go | 95 +++ common/common_test.go | 80 +++ common/{iterutil.go => helpers.go} | 4 +- common/{iterutil_test.go => helpers_test.go} | 0 common/interface.go | 53 -- common/refutil.go | 69 --- common/refutil_test.go | 124 ---- common/sliceutil.go | 33 -- common/sliceutil_test.go | 41 -- common/strutil.go | 9 - common/strutil_test.go | 15 - common/timeutil.go | 25 - common/timeutil_test.go | 40 -- datastructures/cache/cache.go | 121 ++++ .../cache}/cache_race_test.go | 4 +- .../cache/{local_test.go => cache_test.go} | 71 ++- datastructures/cache/errors.go | 4 +- datastructures/cache/local.go | 102 ---- datastructures/cache/multilevel.go | 27 +- datastructures/cache/multilevel_test.go | 42 +- .../queuecache}/cache.go | 27 +- .../queuecache}/cache_test.go | 2 +- datastructures/set/functions.go | 56 -- datastructures/set/functions_test.go | 156 ----- datastructures/set/implementations.go | 355 ----------- datastructures/set/implementations_test.go | 561 ------------------ datastructures/set/set.go | 103 +++- datastructures/set/set_test.go | 41 ++ go.mod | 11 +- go.sum | 14 + {hasher => hashing}/hasher.go | 2 +- {hasher => hashing}/hasher_test.go | 2 +- {provisional/hashing => hashing}/murmur128.go | 0 .../hashing => hashing}/murmur128_test.go | 4 +- {hasher => hashing}/murmur32.go | 2 +- {hasher => hashing}/util.go | 2 +- {hasher => hashing}/util_test.go | 2 +- logging/interface.go | 5 + logging/levels.go | 39 ++ logging/levels_test.go | 253 ++++---- logging/logging.go | 25 + logging/mocks/mocks.go | 54 +- provisional/int64cache/cache.go | 91 --- provisional/int64cache/cache_test.go | 61 -- redis/wrapper_test.go | 4 +- .../jsonvalidator}/validator.go | 8 +- .../jsonvalidator}/validator_test.go | 2 +- {testfiles => testdata}/murmur3_64_uuids.csv | 0 testhelpers/helpers.go | 48 -- 50 files changed, 806 insertions(+), 2100 deletions(-) create mode 100644 common/common.go create mode 100644 common/common_test.go rename common/{iterutil.go => helpers.go} (86%) rename common/{iterutil_test.go => helpers_test.go} (100%) delete mode 100644 common/interface.go delete mode 100644 common/refutil.go delete mode 100644 common/refutil_test.go delete mode 100644 common/sliceutil.go delete mode 100644 common/sliceutil_test.go delete mode 100644 common/strutil.go delete mode 100644 common/strutil_test.go delete mode 100644 common/timeutil.go delete mode 100644 common/timeutil_test.go create mode 100644 datastructures/cache/cache.go rename {provisional/int64cache => datastructures/cache}/cache_race_test.go (88%) rename datastructures/cache/{local_test.go => cache_test.go} (67%) delete mode 100644 datastructures/cache/local.go rename {queuecache => datastructures/queuecache}/cache.go (74%) rename {queuecache => datastructures/queuecache}/cache_test.go (98%) delete mode 100644 datastructures/set/functions.go delete mode 100644 datastructures/set/functions_test.go delete mode 100644 datastructures/set/implementations.go delete mode 100644 datastructures/set/implementations_test.go create mode 100644 datastructures/set/set_test.go rename {hasher => hashing}/hasher.go (97%) rename {hasher => hashing}/hasher_test.go (99%) rename {provisional/hashing => hashing}/murmur128.go (100%) rename {provisional/hashing => hashing}/murmur128_test.go (88%) rename {hasher => hashing}/murmur32.go (99%) rename {hasher => hashing}/util.go (95%) rename {hasher => hashing}/util_test.go (96%) delete mode 100644 provisional/int64cache/cache.go delete mode 100644 provisional/int64cache/cache_test.go rename {json-struct-validator => struct/jsonvalidator}/validator.go (94%) rename {json-struct-validator => struct/jsonvalidator}/validator_test.go (99%) rename {testfiles => testdata}/murmur3_64_uuids.csv (100%) delete mode 100644 testhelpers/helpers.go diff --git a/backoff/mocks/mocks.go b/backoff/mocks/mocks.go index ee71a71..dca07d8 100644 --- a/backoff/mocks/mocks.go +++ b/backoff/mocks/mocks.go @@ -1,16 +1,23 @@ package mocks -import "time" +import ( + "github.com/splitio/go-toolkit/v5/backoff" + "github.com/stretchr/testify/mock" + "time" +) type BackoffMock struct { - NextCall func() time.Duration - ResetCall func() + mock.Mock } +// Next implements backoff.Interface. func (b *BackoffMock) Next() time.Duration { - return b.NextCall() + return b.Called().Get(0).(time.Duration) } +// Reset implements backoff.Interface. func (b *BackoffMock) Reset() { - b.ResetCall() + b.Called() } + +var _ backoff.Interface = (*BackoffMock)(nil) diff --git a/common/common.go b/common/common.go new file mode 100644 index 0000000..b3bbff8 --- /dev/null +++ b/common/common.go @@ -0,0 +1,95 @@ +package common + +import "cmp" + +// New helpers to be used when with newer go versions. +// the rest of the common package should be removed in v3, and consumers of the lib +// should only rely on these functions + +// Ref creates a copy of `x` in heap and returns a pointer to it +func Ref[T any](x T) *T { + return &x +} + +// RefOrNil returns a pointer to the value supplied if it's not the default value, nil otherwise +func RefOrNil[T comparable](x T) *T { + var t T + if x == t { + return nil + } + return &x +} + +// PointerOf performs a type-assertion to T and returns a pointer if successful, nil otherwise. +func PointerOf[T any](x interface{}) *T { + if x == nil { + return nil + } + + ta, ok := x.(T) + if !ok { + return nil + } + + return &ta +} + +// PartitionSliceByLength partitions a slice into multiple slices of up to `maxItems` size +func PartitionSliceByLength[T comparable](items []T, maxItems int) [][]T { + var splitted [][]T + for i := 0; i < len(items); i += maxItems { + end := i + maxItems + if end > len(items) { + end = len(items) + } + splitted = append(splitted, items[i:end]) + } + return splitted +} + +// DedupeInNewSlice creates a new slice from `items` without duplicate elements +func UnorderedDedupedCopy[T comparable](items []T) []T { + present := make(map[T]struct{}, len(items)) + for idx := range items { + present[items[idx]] = struct{}{} + } + + ret := make([]T, 0, len(present)) + for key := range present { + ret = append(ret, key) + } + + return ret +} + +// ValueOr returns the supplied value if it has something other than the default value +// for type T. Returns `fallback` otherwise +func ValueOr[T comparable](in T, fallback T) T { + var t T + if in == t { + return fallback + } + return in +} + +// Max returns the greatest item of all supplied +func Max[T cmp.Ordered](i1 T, rest ...T) T { + max := i1 + for idx := range rest { + if rest[idx] > max { + max = rest[idx] + } + } + return max +} + +// Min returns the minimum item of all supplied +func Min[T cmp.Ordered](i1 T, rest ...T) T { + min := i1 + for idx := range rest { + if rest[idx] < min { + min = rest[idx] + } + } + return min +} diff --git a/common/common_test.go b/common/common_test.go new file mode 100644 index 0000000..a010f41 --- /dev/null +++ b/common/common_test.go @@ -0,0 +1,80 @@ +package common + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRef(t *testing.T) { + str1 := "hello" + assert.Equal(t, &str1, Ref[string]("hello")) + i64 := int64(123456) + assert.Equal(t, &i64, Ref[int64](123456)) +} + +func TestRefOrNil(t *testing.T) { + str1 := "hello" + assert.Equal(t, &str1, RefOrNil("hello")) + assert.Equal(t, (*string)(nil), RefOrNil("")) +} + +func TestPartitionSliceByLength(t *testing.T) { + maxItems := 10000 + keys := make([]string, 0) + partition := PartitionSliceByLength(keys, maxItems) + if len(partition) != 0 { + t.Error("Unexpected quantity of partition") + } + + keys2 := make([]string, 0, 10000) + for i := 0; i < 300; i++ { + keys2 = append(keys2, fmt.Sprintf("test_%d", i)) + } + + partition2 := PartitionSliceByLength(keys2, maxItems) + if len(partition2) != 1 { + t.Error("Unexpected quantity of partition") + } + if len(partition2[0]) != 300 { + t.Error("Unexpected items per chunk") + } + + keys3 := make([]string, 0, 15000) + for i := 0; i < 15000; i++ { + keys3 = append(keys3, fmt.Sprintf("test_%d", i)) + } + + partition3 := PartitionSliceByLength(keys3, maxItems) + if len(partition3) != 2 { + t.Error("Unexpected quantity of partition") + } + if len(partition3[0]) != 10000 || len(partition3[1]) != 5000 { + t.Error("Unexpected items per chunk") + } +} + +func TestDedupeInNewSlice(t *testing.T) { + assert.ElementsMatch(t, []int{1, 2, 3}, UnorderedDedupedCopy([]int{3, 2, 2, 3, 1})) + assert.ElementsMatch(t, []int{1, 2, 3}, UnorderedDedupedCopy([]int{1, 2, 3})) + assert.ElementsMatch(t, []int{}, UnorderedDedupedCopy([]int{})) + assert.ElementsMatch(t, []string{"a", "c"}, UnorderedDedupedCopy([]string{"c", "c", "a"})) +} + +func TestValueOr(t *testing.T) { + assert.Equal(t, int64(3), ValueOr[int64](0, 3)) + assert.Equal(t, (*int)(nil), ValueOr[*int](nil, nil)) + assert.Equal(t, Ref(int(3)), ValueOr[*int](nil, Ref(int(3)))) + assert.Equal(t, Ref(int(4)), ValueOr[*int](Ref(int(4)), Ref(int(3)))) +} + +func TestMax(t *testing.T) { + assert.Equal(t, int(5), Max[int](1, 2, 3, 4, 5)) + assert.Equal(t, int(5), Max[int](5)) +} + +func TestMin(t *testing.T) { + assert.Equal(t, int(1), Min[int](1, 2, 3, 4, 5)) + assert.Equal(t, int(1), Min[int](1)) +} diff --git a/common/iterutil.go b/common/helpers.go similarity index 86% rename from common/iterutil.go rename to common/helpers.go index 4c1b920..570024b 100644 --- a/common/iterutil.go +++ b/common/helpers.go @@ -40,14 +40,14 @@ func WithBackoffCancelling(unit time.Duration, max time.Duration, main func() bo isDone := main() // Create timeout timer for backoff - backoffTimer := time.NewTimer(MinDuration(time.Duration(math.Pow(2, float64(attempts)))*unit, max)) + backoffTimer := time.NewTimer(Min(time.Duration(math.Pow(2, float64(attempts)))*unit, max)) defer backoffTimer.Stop() for !isDone { attempts++ // Setting timer considerint attempts - backoffTimer.Reset(MinDuration(time.Duration(math.Pow(2, float64(attempts)))*unit, max)) + backoffTimer.Reset(Min(time.Duration(math.Pow(2, float64(attempts)))*unit, max)) select { case <-cancel: diff --git a/common/iterutil_test.go b/common/helpers_test.go similarity index 100% rename from common/iterutil_test.go rename to common/helpers_test.go diff --git a/common/interface.go b/common/interface.go deleted file mode 100644 index b77e368..0000000 --- a/common/interface.go +++ /dev/null @@ -1,53 +0,0 @@ -package common - -// AsIntOrNil returns ref -func AsIntOrNil(data interface{}) *int { - if data == nil { - return nil - } - - number, ok := data.(int) - if !ok { - return nil - } - return IntRef(number) -} - -// AsInt64OrNil returns ref -func AsInt64OrNil(data interface{}) *int64 { - if data == nil { - return nil - } - - number, ok := data.(int64) - if !ok { - return nil - } - return Int64Ref(number) -} - -// AsFloat64OrNil return ref -func AsFloat64OrNil(data interface{}) *float64 { - if data == nil { - return nil - } - - number, ok := data.(float64) - if !ok { - return nil - } - return Float64Ref(number) -} - -// AsStringOrNil returns ref -func AsStringOrNil(data interface{}) *string { - if data == nil { - return nil - } - - str, ok := data.(string) - if !ok { - return nil - } - return StringRef(str) -} diff --git a/common/refutil.go b/common/refutil.go deleted file mode 100644 index 04bea36..0000000 --- a/common/refutil.go +++ /dev/null @@ -1,69 +0,0 @@ -package common - -// StringRef returns ref -func StringRef(str string) *string { - return &str -} - -// StringFromRef returns original value if not empty. Default otherwise. -func StringFromRef(str *string) string { - if str == nil { - return "" - } - return *str -} - -// IntRef returns ref -func IntRef(number int) *int { - return &number -} - -// IntFromRef returns 0 if nil, dereferenced value otherwhise. -func IntFromRef(ref *int) int { - if ref == nil { - return 0 - } - return *ref -} - -// Int64Ref returns ref -func Int64Ref(number int64) *int64 { - return &number -} - -// Int64FromRef returns value -func Int64FromRef(number *int64) int64 { - if number == nil { - return 0 - } - return *number -} - -// Float64Ref returns ref -func Float64Ref(number float64) *float64 { - return &number -} - -// IntRefOrNil returns ref -func IntRefOrNil(number int) *int { - if number == 0 { - return nil - } - return IntRef(number) -} - -// Int64RefOrNil returns ref -func Int64RefOrNil(number int64) *int64 { - if number == 0 { - return nil - } - return Int64Ref(number) -} - -// StringRefOrNil returns ref -func StringRefOrNil(str string) *string { - if str == "" { - return nil - } - return StringRef(str) -} diff --git a/common/refutil_test.go b/common/refutil_test.go deleted file mode 100644 index dde6618..0000000 --- a/common/refutil_test.go +++ /dev/null @@ -1,124 +0,0 @@ -package common - -import ( - "testing" -) - -func TestStringRef(t *testing.T) { - a := "someString" - if a != *StringRef(a) { - t.Error("Wrong string reference") - } -} - -func TestStringFromRef(t *testing.T) { - s1 := "string1" - var s2 *string = nil - if StringFromRef(&s1) != "string1" { - t.Error("Should have returned original value") - } - - if StringFromRef(s2) != "" { - t.Error("Should have returned empty string") - } -} - -func TestInt64Ref(t *testing.T) { - a := int64(3) - if a != *Int64Ref(a) { - t.Error("Wrong int64 reference") - } -} - -func TestInt64FromRef(t *testing.T) { - a := int64(3) - if Int64FromRef(&a) != 3 { - t.Error("Should be 3") - } - - if Int64FromRef(nil) != 0 { - t.Error("Should be 0") - } -} - -func TestIntRef(t *testing.T) { - a := int(43) - if a != *IntRef(a) { - t.Error("Wrong int reference") - } -} - -func TestStringRefOrNil(t *testing.T) { - a := "someString" - if a != *StringRefOrNil(a) { - t.Error("Wrong string reference") - } - - if StringRefOrNil("") != nil { - t.Error("Should be nil") - } -} - -func TestInt64RefOrNil(t *testing.T) { - a := int64(3) - if a != *Int64RefOrNil(a) { - t.Error("Wrong int64 reference") - } - if Int64RefOrNil(0) != nil { - t.Error("Should be nil") - } -} - -func TestIntRefOrNil(t *testing.T) { - a := int(43) - if a != *IntRefOrNil(a) { - t.Error("Wrong int reference") - } - if IntRefOrNil(0) != nil { - t.Error("Should be nil") - } -} - -func TestAsIntOrNil(t *testing.T) { - a := AsIntOrNil(123456789) - if a == nil { - t.Error("Wrong int reference") - } - b := AsIntOrNil("some") - if b != nil { - t.Error("It should be nil") - } -} - -func TestAsInt64OrNil(t *testing.T) { - a := AsInt64OrNil(int64(123456789)) - if a == nil { - t.Error("Wrong int reference") - } - b := AsInt64OrNil("some") - if b != nil { - t.Error("It should be nil") - } -} - -func TestAsFloat64OrNil(t *testing.T) { - a := AsFloat64OrNil(float64(123456789)) - if a == nil { - t.Error("Wrong int reference") - } - b := AsFloat64OrNil("some") - if b != nil { - t.Error("It should be nil") - } -} - -func TestAsStringOrNil(t *testing.T) { - a := AsStringOrNil("some") - if a == nil { - t.Error("Wrong string reference") - } - b := AsStringOrNil(int64(123456789)) - if b != nil { - t.Error("It should be nil") - } -} diff --git a/common/sliceutil.go b/common/sliceutil.go deleted file mode 100644 index 676ea42..0000000 --- a/common/sliceutil.go +++ /dev/null @@ -1,33 +0,0 @@ -package common - -// Partition create partitions considering the passed amount -func Partition(items []string, maxItems int) [][]string { - var splitted [][]string - - for i := 0; i < len(items); i += maxItems { - end := i + maxItems - - if end > len(items) { - end = len(items) - } - - splitted = append(splitted, items[i:end]) - } - - return splitted -} - -// DedupeStringSlice returns a new slice without duplicate strings -func DedupeStringSlice(items []string) []string { - present := make(map[string]struct{}, len(items)) - for idx := range items { - present[items[idx]] = struct{}{} - } - - ret := make([]string, 0, len(present)) - for key := range present { - ret = append(ret, key) - } - - return ret -} diff --git a/common/sliceutil_test.go b/common/sliceutil_test.go deleted file mode 100644 index f5f1726..0000000 --- a/common/sliceutil_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package common - -import ( - "fmt" - "testing" -) - -func TestPartition(t *testing.T) { - maxItems := 10000 - keys := make([]string, 0) - partition := Partition(keys, maxItems) - if len(partition) != 0 { - t.Error("Unexpected quantity of partition") - } - - keys2 := make([]string, 0, 10000) - for i := 0; i < 300; i++ { - keys2 = append(keys2, fmt.Sprintf("test_%d", i)) - } - - partition2 := Partition(keys2, maxItems) - if len(partition2) != 1 { - t.Error("Unexpected quantity of partition") - } - if len(partition2[0]) != 300 { - t.Error("Unexpected items per chunk") - } - - keys3 := make([]string, 0, 15000) - for i := 0; i < 15000; i++ { - keys3 = append(keys3, fmt.Sprintf("test_%d", i)) - } - - partition3 := Partition(keys3, maxItems) - if len(partition3) != 2 { - t.Error("Unexpected quantity of partition") - } - if len(partition3[0]) != 10000 || len(partition3[1]) != 5000 { - t.Error("Unexpected items per chunk") - } -} diff --git a/common/strutil.go b/common/strutil.go deleted file mode 100644 index 35dfca1..0000000 --- a/common/strutil.go +++ /dev/null @@ -1,9 +0,0 @@ -package common - -// StringValueOrDefault returns original value if not empty. Default otherwise. -func StringValueOrDefault(str string, def string) string { - if str != "" { - return str - } - return def -} diff --git a/common/strutil_test.go b/common/strutil_test.go deleted file mode 100644 index c70f93c..0000000 --- a/common/strutil_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package common - -import ( - "testing" -) - -func TestStringValueOrDefault(t *testing.T) { - if StringValueOrDefault("abc", "def") != "abc" { - t.Error("Should have returned original value") - } - - if StringValueOrDefault("", "def") != "def" { - t.Error("Should have returned default value") - } -} diff --git a/common/timeutil.go b/common/timeutil.go deleted file mode 100644 index d166bef..0000000 --- a/common/timeutil.go +++ /dev/null @@ -1,25 +0,0 @@ -package common - -import "time" - -// MinDuration returns the min duration among them -func MinDuration(d1, d2 time.Duration, ds ...time.Duration) time.Duration { - min := d1 - for _, d := range append(ds, d2) { - if d < min { - min = d - } - } - return min -} - -// MaxDuration returns the max duration among them -func MaxDuration(d1, d2 time.Duration, ds ...time.Duration) time.Duration { - max := d1 - for _, d := range append(ds, d2) { - if d > max { - max = d - } - } - return max -} diff --git a/common/timeutil_test.go b/common/timeutil_test.go deleted file mode 100644 index 0db316d..0000000 --- a/common/timeutil_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package common - -import ( - "testing" - "time" -) - -func TestMinDuration(t *testing.T) { - m := MinDuration(1*time.Second, 2*time.Second) - if m.Seconds() != 1 { - t.Error("Unexpected result") - } - - m2 := MinDuration(2*time.Second, 1*time.Second) - if m2.Seconds() != 1 { - t.Error("Unexpected result") - } - - m3 := MinDuration(2*time.Minute, 1*time.Second, 2*time.Millisecond) - if m3.Milliseconds() != 2 { - t.Error("Unexpected result") - } -} - -func TestMaxDuration(t *testing.T) { - m := MaxDuration(1*time.Second, 2*time.Second) - if m.Seconds() != 2 { - t.Error("Unexpected result") - } - - m2 := MaxDuration(2*time.Second, 1*time.Second) - if m2.Seconds() != 2 { - t.Error("Unexpected result") - } - - m3 := MaxDuration(4*time.Minute, 1*time.Second, 2*time.Millisecond) - if m3.Seconds() != 240 { - t.Error("Unexpected result") - } -} diff --git a/datastructures/cache/cache.go b/datastructures/cache/cache.go new file mode 100644 index 0000000..c1ed7dc --- /dev/null +++ b/datastructures/cache/cache.go @@ -0,0 +1,121 @@ +package cache + +import ( + "container/list" + "fmt" + "sync" + "time" +) + +const ( + NoTTL = -1 +) + +// SimpleLRU is an in-memory TTL & LRU cache +type SimpleLRU[K comparable, V comparable] interface { + Get(key K) (V, error) + Set(key K, value V) error +} + +// SimpleLRUImpl implements the Simple interface +type SimpleLRUImpl[K comparable, V comparable] struct { + ttl time.Duration + maxLen int + ttls map[K]time.Time + items map[K]*list.Element + lru *list.List + mutex sync.Mutex +} + +type centry[K comparable, V comparable] struct { + key K + value V +} + +// Get retrieves an item if exist, nil + an error otherwise +func (c *SimpleLRUImpl[K, V]) Get(key K) (V, error) { + var empty V + c.mutex.Lock() + defer c.mutex.Unlock() + node, ok := c.items[key] + if !ok { + return empty, &Miss{Where: "LOCAL", Key: key} + } + + entry, ok := node.Value.(centry[K, V]) + if !ok { + return empty, fmt.Errorf("Invalid data in cache for key %v", key) + } + + if c.ttls != nil { // TTL enabled + ttl, ok := c.ttls[key] + if !ok { + return empty, fmt.Errorf( + "Missing TTL for key %v. Wrapping as expired: %w", + key, + &Expired{Key: key, Value: entry.value, When: ttl.Add(c.ttl)}, + ) + } + + if time.Now().UnixNano() > ttl.UnixNano() { + return empty, &Expired{Key: key, Value: entry.value, When: ttl.Add(c.ttl)} + } + } + + c.lru.MoveToFront(node) + return entry.value, nil +} + +// Set adds a new item. Since the cache being full results in removing the LRU element, this method never fails. +func (c *SimpleLRUImpl[K, V]) Set(key K, value V) error { + c.mutex.Lock() + defer c.mutex.Unlock() + if node, ok := c.items[key]; ok { + c.lru.MoveToFront(node) + node.Value = centry[K, V]{key: key, value: value} + } else { + // Drop the LRU item on the list before adding a new one. + if c.lru.Len() == c.maxLen { + entry, ok := c.lru.Back().Value.(centry[K, V]) + if !ok { + return fmt.Errorf("Invalid data in list for key %v", key) + } + key := entry.key + delete(c.items, key) + if c.ttls != nil { + delete(c.ttls, key) + } + c.lru.Remove(c.lru.Back()) + } + + ptr := c.lru.PushFront(centry[K, V]{key: key, value: value}) + c.items[key] = ptr + } + + if c.ttls != nil { + c.ttls[key] = time.Now().Add(c.ttl) + } + return nil +} + +// NewSimple returns a new Simple instance of the specified size and TTL +func NewSimpleLRU[K comparable, V comparable](maxSize int, ttl time.Duration) (*SimpleLRUImpl[K, V], error) { + if maxSize <= 0 { + return nil, fmt.Errorf("Cache size should be > 0. Is: %d", maxSize) + } + + var ttls map[K]time.Time = nil + if ttl != NoTTL { + ttls = make(map[K]time.Time) + } + + return &SimpleLRUImpl[K, V]{ + maxLen: maxSize, + ttl: ttl, + lru: new(list.List), + items: make(map[K]*list.Element, maxSize), + ttls: ttls, + }, nil +} + +var _ SimpleLRU[string, int] = (*SimpleLRUImpl[string, int])(nil) diff --git a/provisional/int64cache/cache_race_test.go b/datastructures/cache/cache_race_test.go similarity index 88% rename from provisional/int64cache/cache_race_test.go rename to datastructures/cache/cache_race_test.go index 255ef20..a8dbbe0 100644 --- a/provisional/int64cache/cache_race_test.go +++ b/datastructures/cache/cache_race_test.go @@ -1,6 +1,6 @@ // +build !race -package int64cache +package cache import ( "math/rand" @@ -10,7 +10,7 @@ import ( func TestLocalCacheHighConcurrency(t *testing.T) { - c, err := NewInt64Cache(500) + c, err := NewSimpleLRU[int64, int64](500, NoTTL) if err != nil { t.Error("No error should have been returned. Got: ", err) } diff --git a/datastructures/cache/local_test.go b/datastructures/cache/cache_test.go similarity index 67% rename from datastructures/cache/local_test.go rename to datastructures/cache/cache_test.go index d4e8ea7..9ff36cc 100644 --- a/datastructures/cache/local_test.go +++ b/datastructures/cache/cache_test.go @@ -1,5 +1,3 @@ -// +build !race - package cache import ( @@ -10,8 +8,8 @@ import ( "time" ) -func TestLocalCache(t *testing.T) { - cache, err := NewLocalCache(5, 1*time.Second) +func TestSimpleCache(t *testing.T) { + cache, err := NewSimpleLRU[string, int](5, 1*time.Second) if err != nil { t.Error("No error should have been returned. Got: ", err) } @@ -50,7 +48,7 @@ func TestLocalCache(t *testing.T) { t.Errorf("Incorrect data within the Miss error. Got: %+v", asMiss) } - if val != nil { + if val != 0 { t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) } @@ -80,7 +78,7 @@ func TestLocalCache(t *testing.T) { time.Sleep(2 * time.Second) // Wait for all keys to expire. for i := 2; i <= 6; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if val != nil { + if val != 0 { t.Errorf("No value should have been returned for expired key 'someKey%d'.", i) } @@ -115,9 +113,9 @@ func TestLocalCache(t *testing.T) { } } -func TestLocalCacheHighConcurrency(t *testing.T) { +func TestSimpleCacheHighConcurrency(t *testing.T) { - cache, err := NewLocalCache(500, 1*time.Second) + cache, err := NewSimpleLRU[string, int](500, 1*time.Second) if err != nil { t.Error("No error should have been returned. Got: ", err) } @@ -142,3 +140,60 @@ func TestLocalCacheHighConcurrency(t *testing.T) { } wg.Wait() } + + +func TestInt64Cache(t *testing.T) { + c, err := NewSimpleLRU[int64, int64](5, NoTTL) + if err != nil { + t.Error("No error should have been returned. Got: ", err) + } + + for i := int64(1); i <= 5; i++ { + err := c.Set(i, i) + if err != nil { + t.Errorf("Setting value '%d', should not have raised an error. Got: %s", i, err) + } + } + + for i := int64(1); i <= 5; i++ { + val, err := c.Get(i) + if err != nil { + t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) + } + if val != i { + t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) + } + } + + c.Set(6, 6) + + // Oldest item (1) should have been removed + val, err := c.Get(1) + if err == nil { + t.Errorf("Getting value 'someKey1', should not have raised an error. Got: %s", err) + } + + _, ok := err.(*Miss) + if !ok { + t.Errorf("Error should be of type Miss. Is %T", err) + } + + if val != 0 { + t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) + } + + // 2-6 should be available + for i := int64(2); i <= 6; i++ { + val, err := c.Get(i) + if err != nil { + t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) + } + if val != i { + t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) + } + } + + if len(c.items) != 5 { + t.Error("Items size should be 5. is: ", len(c.items)) + } +} diff --git a/datastructures/cache/errors.go b/datastructures/cache/errors.go index d8015bb..5e4a559 100644 --- a/datastructures/cache/errors.go +++ b/datastructures/cache/errors.go @@ -8,7 +8,7 @@ import ( // Miss is a special type of error indicating that a key was not found type Miss struct { Where string - Key string + Key interface{} } func (c *Miss) Error() string { @@ -17,7 +17,7 @@ func (c *Miss) Error() string { // Expired is a special type of error indicating that a key is no longer valid (value is still attached in the error) type Expired struct { - Key string + Key interface{} When time.Time Value interface{} } diff --git a/datastructures/cache/local.go b/datastructures/cache/local.go deleted file mode 100644 index 699d3b8..0000000 --- a/datastructures/cache/local.go +++ /dev/null @@ -1,102 +0,0 @@ -package cache - -import ( - "container/list" - "fmt" - "sync" - "time" -) - -// LocalCache is an in-memory TTL & LRU cache -type LocalCache interface { - Get(key string) (interface{}, error) - Set(key string, value interface{}) error -} - -// LocalCacheImpl implements the LocalCache interface -type LocalCacheImpl struct { - ttl time.Duration - maxLen int - ttls map[string]time.Time - items map[string]*list.Element - lru *list.List - mutex sync.Mutex -} - -type entry struct { - key string - value interface{} -} - -// Get retrieves an item if exist, nil + an error otherwise -func (c *LocalCacheImpl) Get(key string) (interface{}, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - node, ok := c.items[key] - if !ok { - return nil, &Miss{Where: "LOCAL", Key: key} - } - - entry, ok := node.Value.(entry) - if !ok { - return nil, fmt.Errorf("Invalid data in cache for key %s", key) - } - - ttl, ok := c.ttls[key] - if !ok { - return nil, fmt.Errorf( - "Missing TTL for key %s. Wrapping as expired: %w", - key, - &Expired{Key: key, Value: entry.value, When: ttl.Add(c.ttl)}, - ) - } - - if time.Now().UnixNano() > ttl.UnixNano() { - return nil, &Expired{Key: key, Value: entry.value, When: ttl.Add(c.ttl)} - } - - c.lru.MoveToFront(node) - return entry.value, nil -} - -// Set adds a new item. Since the cache being full results in removing the LRU element, this method never fails. -func (c *LocalCacheImpl) Set(key string, value interface{}) error { - c.mutex.Lock() - defer c.mutex.Unlock() - if node, ok := c.items[key]; ok { - c.lru.MoveToFront(node) - node.Value = entry{key: key, value: value} - } else { - // Drop the LRU item on the list before adding a new one. - if c.lru.Len() == c.maxLen { - entry, ok := c.lru.Back().Value.(entry) - if !ok { - return fmt.Errorf("Invalid data in list for key %s", key) - } - key := entry.key - delete(c.items, key) - delete(c.ttls, key) - c.lru.Remove(c.lru.Back()) - } - - ptr := c.lru.PushFront(entry{key: key, value: value}) - c.items[key] = ptr - } - c.ttls[key] = time.Now().Add(c.ttl) - return nil -} - -// NewLocalCache returns a new LocalCache instance of the specified size and TTL -func NewLocalCache(maxSize int, ttl time.Duration) (*LocalCacheImpl, error) { - if maxSize <= 0 { - return nil, fmt.Errorf("Cache size should be > 0. Is: %d", maxSize) - } - - return &LocalCacheImpl{ - maxLen: maxSize, - ttl: ttl, - lru: new(list.List), - items: make(map[string]*list.Element, maxSize), - ttls: make(map[string]time.Time), - }, nil -} diff --git a/datastructures/cache/multilevel.go b/datastructures/cache/multilevel.go index d0f0451..917eb8d 100644 --- a/datastructures/cache/multilevel.go +++ b/datastructures/cache/multilevel.go @@ -7,26 +7,26 @@ import ( ) // MLCLayer is the interface that should be implemented for all caching structs to be used with this piece of code. -type MLCLayer interface { - Get(ctx context.Context, key string) (interface{}, error) - Set(ctx context.Context, key string, value interface{}) error +type MLCLayer[K comparable, V comparable] interface { + Get(ctx context.Context, key K) (V, error) + Set(ctx context.Context, key K, value V) error } // MultiLevelCache bundles a list of ordered cache layers (upper -> lower) -type MultiLevelCache interface { - Get(ctx context.Context, key string) (interface{}, error) +type MultiLevelCache[K comparable, V comparable] interface { + Get(ctx context.Context, key K) (V, error) } // MultiLevelCacheImpl implements the MultiLevelCache interface -type MultiLevelCacheImpl struct { - layers []MLCLayer +type MultiLevelCacheImpl[K comparable, V comparable] struct { + layers []MLCLayer[K, V] logger logging.LoggerInterface } // Get returns the value of the requested key (if found) and populates upper levels with it -func (c *MultiLevelCacheImpl) Get(ctx context.Context, key string) (interface{}, error) { +func (c *MultiLevelCacheImpl[K, V]) Get(ctx context.Context, key K) (V, error) { toUpdate := make([]int, 0, len(c.layers)) - var item interface{} + var item V var err error for index, layer := range c.layers { item, err = layer.Get(ctx, key) @@ -49,8 +49,9 @@ func (c *MultiLevelCacheImpl) Get(ctx context.Context, key string) (interface{}, } } - if item == nil || err != nil { - return nil, &Miss{Where: "ALL_LEVELS", Key: key} + var empty V + if item == empty || err != nil { + return empty, &Miss{Where: "ALL_LEVELS", Key: key} } // Update upper layers if any @@ -66,10 +67,10 @@ func (c *MultiLevelCacheImpl) Get(ctx context.Context, key string) (interface{}, } // NewMultiLevel creates and returns a new MultiLevelCache instance -func NewMultiLevel(layers []MLCLayer, logger logging.LoggerInterface) (*MultiLevelCacheImpl, error) { +func NewMultiLevel[K comparable, V comparable](layers []MLCLayer[K, V], logger logging.LoggerInterface) (*MultiLevelCacheImpl[K, V], error) { if logger == nil { logger = logging.NewLogger(nil) } - return &MultiLevelCacheImpl{layers: layers, logger: logger}, nil + return &MultiLevelCacheImpl[K, V]{layers: layers, logger: logger}, nil } diff --git a/datastructures/cache/multilevel_test.go b/datastructures/cache/multilevel_test.go index b3306a3..c3549d2 100644 --- a/datastructures/cache/multilevel_test.go +++ b/datastructures/cache/multilevel_test.go @@ -10,15 +10,15 @@ import ( ) type LayerMock struct { - getCall func(ctx context.Context, key string) (interface{}, error) - setCall func(ctx context.Context, key string, value interface{}) error + getCall func(ctx context.Context, key string) (string, error) + setCall func(ctx context.Context, key string, value string) error } -func (m *LayerMock) Get(ctx context.Context, key string) (interface{}, error) { +func (m *LayerMock) Get(ctx context.Context, key string) (string, error) { return m.getCall(ctx, key) } -func (m *LayerMock) Set(ctx context.Context, key string, value interface{}) error { +func (m *LayerMock) Set(ctx context.Context, key string, value string) error { return m.setCall(ctx, key, value) } @@ -56,20 +56,20 @@ func TestMultiLevelCache(t *testing.T) { // Bottom layer fails if key1 or 2 are requested, has key 3. returns Miss if any other key is requested calls := newCallTracker(t) topLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (interface{}, error) { + getCall: func(ctx context.Context, key string) (string, error) { calls.track(fmt.Sprintf("top:get:%s", key)) switch key { case "key1": return "value1", nil case "key2": - return nil, &Miss{Where: "layer1", Key: "key2"} + return "", &Miss{Where: "layer1", Key: "key2"} case "key3": - return nil, &Expired{Key: "key3", Value: "someOtherValue"} + return "", &Expired{Key: "key3", Value: "someOtherValue"} default: - return nil, errors.New("someError") + return "", errors.New("someError") } }, - setCall: func(ctx context.Context, key string, value interface{}) error { + setCall: func(ctx context.Context, key string, value string) error { calls.track(fmt.Sprintf("top:set:%s", key)) switch key { case "key1": @@ -87,19 +87,19 @@ func TestMultiLevelCache(t *testing.T) { } midLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (interface{}, error) { + getCall: func(ctx context.Context, key string) (string, error) { calls.track(fmt.Sprintf("mid:get:%s", key)) switch key { case "key1": t.Error("Get should not be called on the mid layer for key1") - return nil, nil + return "", nil case "key2": return "value2", nil default: - return nil, &Miss{Where: "layer2", Key: key} + return "", &Miss{Where: "layer2", Key: key} } }, - setCall: func(ctx context.Context, key string, value interface{}) error { + setCall: func(ctx context.Context, key string, value string) error { calls.track(fmt.Sprintf("mid:set:%s", key)) switch key { case "key1": @@ -115,22 +115,22 @@ func TestMultiLevelCache(t *testing.T) { } bottomLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (interface{}, error) { + getCall: func(ctx context.Context, key string) (string, error) { calls.track(fmt.Sprintf("bot:get:%s", key)) switch key { case "key1": t.Error("Get should not be called on the mid layer for key1") - return nil, nil + return "", nil case "key2": t.Error("Get should not be called on the mid layer for key1") - return nil, nil + return "", nil case "key3": return "value3", nil default: - return nil, &Miss{Where: "layer3", Key: key} + return "", &Miss{Where: "layer3", Key: key} } }, - setCall: func(ctx context.Context, key string, value interface{}) error { + setCall: func(ctx context.Context, key string, value string) error { calls.track(fmt.Sprintf("bot:set:%s", key)) switch key { case "key1": @@ -144,9 +144,9 @@ func TestMultiLevelCache(t *testing.T) { }, } - cacheML := MultiLevelCacheImpl{ + cacheML := MultiLevelCacheImpl[string, string]{ logger: logging.NewLogger(nil), - layers: []MLCLayer{topLayer, midLayer, bottomLayer}, + layers: []MLCLayer[string, string]{topLayer, midLayer, bottomLayer}, } value1, err := cacheML.Get(context.TODO(), "key1") @@ -203,7 +203,7 @@ func TestMultiLevelCache(t *testing.T) { t.Errorf("Incorrect 'Where' or 'Key'. Got: %+v", asMiss) } - if value4 != nil { + if value4 != "" { t.Errorf("Value returned for GET 'key4' should be nil. Is: %+v", value4) } calls.checkCall("top:get:key4", 1) diff --git a/queuecache/cache.go b/datastructures/queuecache/cache.go similarity index 74% rename from queuecache/cache.go rename to datastructures/queuecache/cache.go index 5bba1a9..617dbd6 100644 --- a/queuecache/cache.go +++ b/datastructures/queuecache/cache.go @@ -29,19 +29,19 @@ func (e *MessagesDroppedError) Error() string { } // InMemoryQueueCacheOverlay offers an in-memory queue that gets re-populated whenever it runs out of items -type InMemoryQueueCacheOverlay struct { +type InMemoryQueueCacheOverlay[T any] struct { maxSize int writeCursor int readCursor int - queue []interface{} + queue []T lock sync.Mutex - refillCustom func(count int) ([]interface{}, error) + refillCustom func(count int) ([]T, error) } // New creates a new InMemoryQueueCacheOverlay -func New(maxSize int, refillFunc func(count int) ([]interface{}, error)) *InMemoryQueueCacheOverlay { - return &InMemoryQueueCacheOverlay{ - queue: make([]interface{}, maxSize), +func New[T any](maxSize int, refillFunc func(count int) ([]T, error)) *InMemoryQueueCacheOverlay[T] { + return &InMemoryQueueCacheOverlay[T]{ + queue: make([]T, maxSize), maxSize: maxSize, writeCursor: 0, readCursor: 0, @@ -50,7 +50,7 @@ func New(maxSize int, refillFunc func(count int) ([]interface{}, error)) *InMemo } // Count returns the number of cached items -func (i *InMemoryQueueCacheOverlay) Count() int { +func (i *InMemoryQueueCacheOverlay[T]) Count() int { if i.writeCursor == i.readCursor { return 0 } else if i.writeCursor > i.readCursor { @@ -59,7 +59,7 @@ func (i *InMemoryQueueCacheOverlay) Count() int { return i.maxSize - (i.readCursor - i.writeCursor) } -func (i *InMemoryQueueCacheOverlay) write(elem interface{}) error { +func (i *InMemoryQueueCacheOverlay[T]) write(elem T) error { if ((i.writeCursor + 1) % i.maxSize) == i.readCursor { return errors.New("QUEUE_FULL") } @@ -69,9 +69,10 @@ func (i *InMemoryQueueCacheOverlay) write(elem interface{}) error { return nil } -func (i *InMemoryQueueCacheOverlay) read() (interface{}, error) { +func (i *InMemoryQueueCacheOverlay[T]) read() (T, error) { if i.readCursor == i.writeCursor { - return nil, errors.New("QUEUE_EMPTY") + var t T + return t, errors.New("QUEUE_EMPTY") } toReturn := i.queue[i.readCursor] @@ -79,7 +80,7 @@ func (i *InMemoryQueueCacheOverlay) read() (interface{}, error) { return toReturn, nil } -func (i *InMemoryQueueCacheOverlay) refillWrapper(count int) (result []interface{}, err error) { +func (i *InMemoryQueueCacheOverlay[T]) refillWrapper(count int) (result []T, err error) { defer func() { if r := recover(); r != nil { result = nil @@ -92,7 +93,7 @@ func (i *InMemoryQueueCacheOverlay) refillWrapper(count int) (result []interface } // Fetch items (will re-populate if necessary) -func (i *InMemoryQueueCacheOverlay) Fetch(requestedCount int) ([]interface{}, error) { +func (i *InMemoryQueueCacheOverlay[T]) Fetch(requestedCount int) ([]T, error) { defer i.lock.Unlock() i.lock.Lock() @@ -110,7 +111,7 @@ func (i *InMemoryQueueCacheOverlay) Fetch(requestedCount int) ([]interface{}, er } } - toReturn := make([]interface{}, int(math.Min(float64(requestedCount), float64(i.Count())))) + toReturn := make([]T, int(math.Min(float64(requestedCount), float64(i.Count())))) for index := 0; index < len(toReturn); index++ { elem, err := i.read() if err != nil { diff --git a/queuecache/cache_test.go b/datastructures/queuecache/cache_test.go similarity index 98% rename from queuecache/cache_test.go rename to datastructures/queuecache/cache_test.go index 45216e1..6e5d734 100644 --- a/queuecache/cache_test.go +++ b/datastructures/queuecache/cache_test.go @@ -133,7 +133,7 @@ func TestRefillPanic(t *testing.T) { } func TestCountWorksProperly(t *testing.T) { - cache := InMemoryQueueCacheOverlay{maxSize: 100} + cache := InMemoryQueueCacheOverlay[int]{maxSize: 100} cache.readCursor = 0 cache.writeCursor = 0 diff --git a/datastructures/set/functions.go b/datastructures/set/functions.go deleted file mode 100644 index 5931a28..0000000 --- a/datastructures/set/functions.go +++ /dev/null @@ -1,56 +0,0 @@ -package set - -// Union calculates the union of two or more sets -func Union(set1, set2 Set, sets ...Set) Set { - u := set1.Copy() - set2.Each(func(item interface{}) bool { - u.Add(item) - return true - }) - - for _, set := range sets { - set.Each(func(item interface{}) bool { - u.Add(item) - return true - }) - } - - return u -} - -// Intersection calculates the intersection of two or more sets -func Intersection(set1, set2 Set, sets ...Set) Set { - all := Union(set1, set2, sets...) - result := Union(set1, set2, sets...) - - all.Each(func(item interface{}) bool { - if !set1.Has(item) || !set2.Has(item) { - result.Remove(item) - } - - for _, set := range sets { - if !set.Has(item) { - result.Remove(item) - } - } - return true - }) - return result -} - -// Difference calculates the difference of two or more sets -func Difference(set1, set2 Set, sets ...Set) Set { - s := set1.Copy() - s.Separate(set2) - for _, set := range sets { - s.Separate(set) // seperate is thread safe - } - return s -} - -// SymmetricDifference calculates the symmetric difference of two or more sets -func SymmetricDifference(s Set, t Set) Set { - u := Difference(s, t) - v := Difference(t, s) - return Union(u, v) -} diff --git a/datastructures/set/functions_test.go b/datastructures/set/functions_test.go deleted file mode 100644 index e764130..0000000 --- a/datastructures/set/functions_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package set - -import ( - "reflect" - "testing" -) - -func Test_Union(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3") - r := NewThreadSafeSet("3", "4", "5") - x := NewSet("5", "6", "7") - - u := Union(s, r, x) - if settype := reflect.TypeOf(u).String(); settype != "*set.ThreadSafeSet" { - t.Error("Union should derive its set type from the first passed set, got", settype) - } - if u.Size() != 7 { - t.Error("Union: the merged set doesn't have all items in it.") - } - - if !u.Has("1", "2", "3", "4", "5", "6", "7") { - t.Error("Union: merged items are not availabile in the set.") - } - - z := Union(x, r) - if z.Size() != 5 { - t.Error("Union: Union of 2 sets doesn't have the proper number of items.") - } - if settype := reflect.TypeOf(z).String(); settype != "*set.ThreadUnsafeSet" { - t.Error("Union should derive its set type from the first passed set, got", settype) - } - -} - -func Test_Difference(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3") - r := NewThreadSafeSet("3", "4", "5") - x := NewThreadSafeSet("5", "6", "7") - u := Difference(s, r, x) - - if u.Size() != 2 { - t.Error("Difference: the set doesn't have all items in it.") - } - - if !u.Has("1", "2") { - t.Error("Difference: items are not availabile in the set.") - } - - y := Difference(r, r) - if y.Size() != 0 { - t.Error("Difference: size should be zero") - } - -} - -func Test_Intersection(t *testing.T) { - s1 := NewThreadSafeSet("1", "3", "4", "5") - s2 := NewThreadSafeSet("2", "3", "5", "6") - s3 := NewThreadSafeSet("4", "5", "6", "7") - u := Intersection(s1, s2, s3) - - if u.Size() != 1 { - t.Error("Intersection: the set doesn't have all items in it.") - } - - if !u.Has("5") { - t.Error("Intersection: items after intersection are not availabile in the set.") - } -} - -func Test_SymmetricDifference(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3") - r := NewThreadSafeSet("3", "4", "5") - u := SymmetricDifference(s, r) - - if u.Size() != 4 { - t.Error("SymmetricDifference: the set doesn't have all items in it.") - } - - if !u.Has("1", "2", "4", "5") { - t.Error("SymmetricDifference: items are not availabile in the set.") - } -} - -func BenchmarkSetEquality(b *testing.B) { - s := NewThreadSafeSet() - u := NewThreadSafeSet() - - for i := 0; i < b.N; i++ { - s.Add(i) - u.Add(i) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - s.IsEqual(u) - } -} - -func BenchmarkSubset(b *testing.B) { - s := NewThreadSafeSet() - u := NewThreadSafeSet() - - for i := 0; i < b.N; i++ { - s.Add(i) - u.Add(i) - } - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - s.IsSubset(u) - } -} - -func benchmarkIntersection(b *testing.B, numberOfItems int) { - s1 := NewThreadSafeSet() - s2 := NewThreadSafeSet() - - for i := 0; i < numberOfItems/2; i++ { - s1.Add(i) - } - for i := 0; i < numberOfItems; i++ { - s2.Add(i) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - Intersection(s1, s2) - } -} - -func BenchmarkIntersection10(b *testing.B) { - benchmarkIntersection(b, 10) -} - -func BenchmarkIntersection100(b *testing.B) { - benchmarkIntersection(b, 100) -} - -func BenchmarkIntersection1000(b *testing.B) { - benchmarkIntersection(b, 1000) -} - -func BenchmarkIntersection10000(b *testing.B) { - benchmarkIntersection(b, 10000) -} - -func BenchmarkIntersection100000(b *testing.B) { - benchmarkIntersection(b, 100000) -} - -func BenchmarkIntersection1000000(b *testing.B) { - benchmarkIntersection(b, 1000000) -} diff --git a/datastructures/set/implementations.go b/datastructures/set/implementations.go deleted file mode 100644 index 91e205c..0000000 --- a/datastructures/set/implementations.go +++ /dev/null @@ -1,355 +0,0 @@ -package set - -import ( - "sync" -) - -var keyExists = struct{}{} // Value that indicates existance in the set for the key element - -// ThreadUnsafeSet structure. Container of unique items with O(1) access time. NOT THREAD SAFE -type ThreadUnsafeSet struct { - set -} - -// NewSet Constructs a new set from an optinal slice of items -func NewSet(items ...interface{}) *ThreadUnsafeSet { - s := &ThreadUnsafeSet{} - s.m = make(map[interface{}]struct{}) - s.Add(items...) - return s -} - -// Add adds new items to the set -func (s *ThreadUnsafeSet) Add(items ...interface{}) { - if len(items) == 0 { - return - } - - for _, item := range items { - s.m[item] = keyExists - - } -} - -// Remove removes items from the set -func (s *ThreadUnsafeSet) Remove(items ...interface{}) { - if len(items) == 0 { - return - } - - for _, item := range items { - delete(s.m, item) - } -} - -// Pop removes an item from the set and returns it. -func (s *set) Pop() interface{} { - for item := range s.m { - delete(s.m, item) - return item - } - return nil -} - -// Has returns true if the items passed are present in the set -func (s *ThreadUnsafeSet) Has(items ...interface{}) bool { - // assume checked for empty item, which not exist - if len(items) == 0 { - return false - } - - has := true - for _, item := range items { - if _, has = s.m[item]; !has { - break - } - } - return has -} - -// Size returns the size of the set -func (s *ThreadUnsafeSet) Size() int { - return len(s.m) -} - -// Clear removes all elements from the set -func (s *ThreadUnsafeSet) Clear() { - s.m = make(map[interface{}]struct{}) -} - -// IsEqual returns true if the received set is equal to this one -func (s *ThreadUnsafeSet) IsEqual(t Set) bool { - // Force locking only if given set is threadsafe. - if conv, ok := t.(*ThreadSafeSet); ok { - conv.l.RLock() - defer conv.l.RUnlock() - } - - // return false if they are no the same size - if sameSize := len(s.m) == t.Size(); !sameSize { - return false - } - - equal := true - t.Each(func(item interface{}) bool { - _, equal = s.m[item] - return equal // if false, Each() will end - }) - - return equal -} - -// IsSubset returns true if the passed set is a subset of this one -func (s *ThreadUnsafeSet) IsSubset(t Set) (subset bool) { - subset = true - - t.Each(func(item interface{}) bool { - _, subset = s.m[item] - return subset - }) - - return subset -} - -// Each executes a passed function on each of the items passed. -func (s *ThreadUnsafeSet) Each(f func(item interface{}) bool) { - for item := range s.m { - if !f(item) { - break - } - } -} - -// List returns a slice of the items in th set -func (s *ThreadUnsafeSet) List() []interface{} { - list := make([]interface{}, 0, len(s.m)) - - for item := range s.m { - list = append(list, item) - } - - return list -} - -// Copy returns a new set with a copy of the elements -func (s *ThreadUnsafeSet) Copy() Set { - return NewSet(s.List()...) -} - -// Merge adds all the elefements in the passed set to this one. -func (s *ThreadUnsafeSet) Merge(t Set) { - t.Each(func(item interface{}) bool { - s.m[item] = keyExists - return true - }) -} - -// Separate removes all the items that are present in the passed set from this set -func (s *ThreadUnsafeSet) Separate(t Set) { - s.Remove(t.List()...) -} - -// IsEmpty returns true if the set has no elements -func (s *ThreadUnsafeSet) IsEmpty() bool { - return s.Size() == 0 -} - -// IsSuperset returns true if the passed set is a supertset of this one -func (s *ThreadUnsafeSet) IsSuperset(t Set) bool { - return t.IsSubset(s) -} - -// ** Thread safe implementation - -// ThreadSafeSet is a thread safe implementation of the set data structure -type ThreadSafeSet struct { - set - l sync.RWMutex -} - -// NewThreadSafeSet instantiates a new ThreadSafeSet -func NewThreadSafeSet(items ...interface{}) *ThreadSafeSet { - s := &ThreadSafeSet{} - s.m = make(map[interface{}]struct{}) - - // Ensure interface compliance - var _ Set = s - - s.Add(items...) - return s -} - -// Add adds a new element to the set -func (s *ThreadSafeSet) Add(items ...interface{}) { - if len(items) == 0 { - return - } - - s.l.Lock() - defer s.l.Unlock() - - for _, item := range items { - s.m[item] = keyExists - } -} - -// Remove deletes an elemenet from the set. -func (s *ThreadSafeSet) Remove(items ...interface{}) { - if len(items) == 0 { - return - } - - s.l.Lock() - defer s.l.Unlock() - - for _, item := range items { - delete(s.m, item) - } -} - -// Pop removes an element from the set and returns it -func (s *ThreadSafeSet) Pop() interface{} { - s.l.RLock() - for item := range s.m { - s.l.RUnlock() - s.l.Lock() - delete(s.m, item) - s.l.Unlock() - return item - } - s.l.RUnlock() - return nil -} - -// Has returns true if the element passed is in the set -func (s *ThreadSafeSet) Has(items ...interface{}) bool { - // assume checked for empty item, which not exist - if len(items) == 0 { - return false - } - - s.l.RLock() - defer s.l.RUnlock() - - has := true - for _, item := range items { - if _, has = s.m[item]; !has { - break - } - } - return has -} - -// Size returns the number of elements in the set -func (s *ThreadSafeSet) Size() int { - s.l.RLock() - defer s.l.RUnlock() - - l := len(s.m) - return l -} - -// Clear removes all the elements in the set -func (s *ThreadSafeSet) Clear() { - s.l.Lock() - defer s.l.Unlock() - - s.m = make(map[interface{}]struct{}) -} - -// IsEqual returns true if the set contains the same elements as the passed one -func (s *ThreadSafeSet) IsEqual(t Set) bool { - s.l.RLock() - defer s.l.RUnlock() - - // Force locking only if given set is threadsafe. - if conv, ok := t.(*ThreadSafeSet); ok { - conv.l.RLock() - defer conv.l.RUnlock() - } - - // return false if they are no the same size - if sameSize := len(s.m) == t.Size(); !sameSize { - return false - } - - equal := true - t.Each(func(item interface{}) bool { - _, equal = s.m[item] - return equal // if false, Each() will end - }) - - return equal -} - -// IsSubset returns true if the passed set is a subset of this one -func (s *ThreadSafeSet) IsSubset(t Set) (subset bool) { - s.l.RLock() - defer s.l.RUnlock() - - subset = true - - t.Each(func(item interface{}) bool { - _, subset = s.m[item] - return subset - }) - - return -} - -// Each executes the passed function on each item from the set -func (s *ThreadSafeSet) Each(f func(item interface{}) bool) { - s.l.RLock() - defer s.l.RUnlock() - - for item := range s.m { - if !f(item) { - break - } - } -} - -// List returns a list with all the elements of the set -func (s *ThreadSafeSet) List() []interface{} { - s.l.RLock() - defer s.l.RUnlock() - - list := make([]interface{}, 0, len(s.m)) - - for item := range s.m { - list = append(list, item) - } - - return list -} - -// Merge adds all the elements of the passed set into this one -func (s *ThreadSafeSet) Merge(t Set) { - s.l.Lock() - defer s.l.Unlock() - - t.Each(func(item interface{}) bool { - s.m[item] = keyExists - return true - }) -} - -// Copy returns a copy of the this thread -func (s *ThreadSafeSet) Copy() Set { - return NewThreadSafeSet(s.List()...) -} - -// Separate removes all the items that are present in the passed set from this set -func (s *ThreadSafeSet) Separate(t Set) { - s.Remove(t.List()...) -} - -// IsEmpty returns true if the set has no elements -func (s *ThreadSafeSet) IsEmpty() bool { - return s.Size() == 0 -} - -// IsSuperset returns true if the passed set is a supertset of this one -func (s *ThreadSafeSet) IsSuperset(t Set) bool { - return t.IsSubset(s) -} diff --git a/datastructures/set/implementations_test.go b/datastructures/set/implementations_test.go deleted file mode 100644 index d603005..0000000 --- a/datastructures/set/implementations_test.go +++ /dev/null @@ -1,561 +0,0 @@ -package set - -import ( - "reflect" - "strconv" - "testing" -) - -func TestSetNonTS_NewNonTS_parameters(t *testing.T) { - s := NewSet("string", "another_string", 1, 3.14) - - if s.Size() != 4 { - t.Error("NewThreadUnsafeThread: calling with parameters should create a set with size of four") - } -} - -func TestSetNonTS_Add(t *testing.T) { - s := NewSet() - s.Add(1) - s.Add(2) - s.Add(2) - s.Add("fatih") - s.Add("zeynep") - s.Add("zeynep") - - if s.Size() != 4 { - t.Error("Add: items are not unique. The set size should be four") - } - - if !s.Has(1, 2, "fatih", "zeynep") { - t.Error("Add: added items are not availabile in the set.") - } -} - -func TestSetNonTS_Add_multiple(t *testing.T) { - s := NewSet() - s.Add("ankara", "san francisco", 3.14) - - if s.Size() != 3 { - t.Error("Add: items are not unique. The set size should be three") - } - - if !s.Has("ankara", "san francisco", 3.14) { - t.Error("Add: added items are not availabile in the set.") - } -} - -func TestSetNonTS_Remove(t *testing.T) { - s := NewSet() - s.Add(1) - s.Add(2) - s.Add("fatih") - - s.Remove(1) - if s.Size() != 2 { - t.Error("Remove: set size should be two after removing") - } - - s.Remove(1) - if s.Size() != 2 { - t.Error("Remove: set size should be not change after trying to remove a non-existing item") - } - - s.Remove(2) - s.Remove("fatih") - if s.Size() != 0 { - t.Error("Remove: set size should be zero") - } - - s.Remove("fatih") // try to remove something from a zero length set -} - -func TestSetNonTS_Remove_multiple(t *testing.T) { - s := NewSet() - s.Add("ankara", "san francisco", 3.14, "istanbul") - s.Remove("ankara", "san francisco", 3.14) - - if s.Size() != 1 { - t.Error("Remove: items are not unique. The set size should be four") - } - - if !s.Has("istanbul") { - t.Error("Add: added items are not availabile in the set.") - } -} - -func TestSetNonTS_Pop(t *testing.T) { - s := NewSet() - s.Add(1) - s.Add(2) - s.Add("fatih") - - a := s.Pop() - if s.Size() != 2 { - t.Error("Pop: set size should be two after popping out") - } - - if s.Has(a) { - t.Error("Pop: returned item should not exist") - } - - s.Pop() - s.Pop() - b := s.Pop() - if b != nil { - t.Error("Pop: should return nil because set is empty") - } - - s.Pop() // try to remove something from a zero length set -} - -func TestSetNonTS_Has(t *testing.T) { - s := NewSet("1", "2", "3", "4") - - if !s.Has("1") { - t.Error("Has: the item 1 exist, but 'Has' is returning false") - } - - if !s.Has("1", "2", "3", "4") { - t.Error("Has: the items all exist, but 'Has' is returning false") - } -} - -func TestSetNonTS_Clear(t *testing.T) { - s := NewSet() - s.Add(1) - s.Add("istanbul") - s.Add("san francisco") - - s.Clear() - if s.Size() != 0 { - t.Error("Clear: set size should be zero") - } -} - -func TestSetNonTS_IsEmpty(t *testing.T) { - s := NewSet() - - empty := s.IsEmpty() - if !empty { - t.Error("IsEmpty: set is empty, it should be true") - } - - s.Add(2) - s.Add(3) - notEmpty := s.IsEmpty() - - if notEmpty { - t.Error("IsEmpty: set is filled, it should be false") - } -} - -func TestSetNonTS_IsEqual(t *testing.T) { - s := NewSet("1", "2", "3") - u := NewSet("1", "2", "3") - - ok := s.IsEqual(u) - if !ok { - t.Error("IsEqual: set s and t are equal. However it returns false") - } - - // same size, different content - a := NewSet("1", "2", "3") - b := NewSet("4", "5", "6") - - ok = a.IsEqual(b) - if ok { - t.Error("IsEqual: set a and b are now equal (1). However it returns true") - } - - // different size, similar content - a = NewSet("1", "2", "3") - b = NewSet("1", "2", "3", "4") - - ok = a.IsEqual(b) - if ok { - t.Error("IsEqual: set s and t are now equal (2). However it returns true") - } - -} - -func TestSetNonTS_IsSubset(t *testing.T) { - s := NewSet("1", "2", "3", "4") - u := NewSet("1", "2", "3") - - ok := s.IsSubset(u) - if !ok { - t.Error("IsSubset: u is a subset of s. However it returns false") - } - - ok = u.IsSubset(s) - if ok { - t.Error("IsSubset: s is not a subset of u. However it returns true") - } - -} - -func TestSetNonTS_IsSuperset(t *testing.T) { - s := NewSet("1", "2", "3", "4") - u := NewSet("1", "2", "3") - - ok := u.IsSuperset(s) - if !ok { - t.Error("IsSuperset: s is a superset of u. However it returns false") - } - - ok = s.IsSuperset(u) - if ok { - t.Error("IsSuperset: u is not a superset of u. However it returns true") - } - -} - -func TestSetNonTS_List(t *testing.T) { - s := NewSet("1", "2", "3", "4") - - // this returns a slice of interface{} - if len(s.List()) != 4 { - t.Error("List: slice size should be four.") - } - - for _, item := range s.List() { - r := reflect.TypeOf(item) - if r.Kind().String() != "string" { - t.Error("List: slice item should be a string") - } - } -} - -func TestSetNonTS_Copy(t *testing.T) { - s := NewSet("1", "2", "3", "4") - r := s.Copy() - - if !s.IsEqual(r) { - t.Error("Copy: set s and r are not equal") - } -} - -func TestSetNonTS_Merge(t *testing.T) { - s := NewSet("1", "2", "3") - r := NewSet("3", "4", "5") - s.Merge(r) - - if s.Size() != 5 { - t.Error("Merge: the set doesn't have all items in it.") - } - - if !s.Has("1", "2", "3", "4", "5") { - t.Error("Merge: merged items are not availabile in the set.") - } -} - -func TestSetNonTS_Separate(t *testing.T) { - s := NewSet("1", "2", "3") - r := NewSet("3", "5") - s.Separate(r) - - if s.Size() != 2 { - t.Error("Separate: the set doesn't have all items in it.") - } - - if !s.Has("1", "2") { - t.Error("Separate: items after separation are not availabile in the set.") - } -} - -// ************************************** -// ****** TEST Thread-Safe Implementation -// ************************************** - -func TestSet_New(t *testing.T) { - s := NewThreadSafeSet() - - if s.Size() != 0 { - t.Error("New: calling without any parameters should create a set with zero size") - } -} - -func TestSet_New_parameters(t *testing.T) { - s := NewThreadSafeSet("string", "another_string", 1, 3.14) - - if s.Size() != 4 { - t.Error("New: calling with parameters should create a set with size of four") - } -} - -func TestSet_Add(t *testing.T) { - s := NewThreadSafeSet() - s.Add(1) - s.Add(2) - s.Add(2) // duplicate - s.Add("fatih") - s.Add("zeynep") - s.Add("zeynep") // another duplicate - - if s.Size() != 4 { - t.Error("Add: items are not unique. The set size should be four") - } - - if !s.Has(1, 2, "fatih", "zeynep") { - t.Error("Add: added items are not availabile in the set.") - } -} - -func TestSet_Add_multiple(t *testing.T) { - s := NewThreadSafeSet() - s.Add("ankara", "san francisco", 3.14) - - if s.Size() != 3 { - t.Error("Add: items are not unique. The set size should be three") - } - - if !s.Has("ankara", "san francisco", 3.14) { - t.Error("Add: added items are not availabile in the set.") - } -} - -func TestSet_Remove(t *testing.T) { - s := NewThreadSafeSet() - s.Add(1) - s.Add(2) - s.Add("fatih") - - s.Remove(1) - if s.Size() != 2 { - t.Error("Remove: set size should be two after removing") - } - - s.Remove(1) - if s.Size() != 2 { - t.Error("Remove: set size should be not change after trying to remove a non-existing item") - } - - s.Remove(2) - s.Remove("fatih") - if s.Size() != 0 { - t.Error("Remove: set size should be zero") - } - - s.Remove("fatih") // try to remove something from a zero length set -} - -func TestSet_Remove_multiple(t *testing.T) { - s := NewThreadSafeSet() - s.Add("ankara", "san francisco", 3.14, "istanbul") - s.Remove("ankara", "san francisco", 3.14) - - if s.Size() != 1 { - t.Error("Remove: items are not unique. The set size should be four") - } - - if !s.Has("istanbul") { - t.Error("Add: added items are not availabile in the set.") - } -} - -func TestSet_Pop(t *testing.T) { - s := NewThreadSafeSet() - s.Add(1) - s.Add(2) - s.Add("fatih") - - a := s.Pop() - if s.Size() != 2 { - t.Error("Pop: set size should be two after popping out") - } - - if s.Has(a) { - t.Error("Pop: returned item should not exist") - } - - s.Pop() - s.Pop() - b := s.Pop() - if b != nil { - t.Error("Pop: should return nil because set is empty") - } - - s.Pop() // try to remove something from a zero length set -} - -func TestSet_Has(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3", "4") - - if !s.Has("1") { - t.Error("Has: the item 1 exist, but 'Has' is returning false") - } - - if !s.Has("1", "2", "3", "4") { - t.Error("Has: the items all exist, but 'Has' is returning false") - } -} - -func TestSet_Clear(t *testing.T) { - s := NewThreadSafeSet() - s.Add(1) - s.Add("istanbul") - s.Add("san francisco") - - s.Clear() - if s.Size() != 0 { - t.Error("Clear: set size should be zero") - } -} - -func TestSet_IsEmpty(t *testing.T) { - s := NewThreadSafeSet() - - empty := s.IsEmpty() - if !empty { - t.Error("IsEmpty: set is empty, it should be true") - } - - s.Add(2) - s.Add(3) - notEmpty := s.IsEmpty() - - if notEmpty { - t.Error("IsEmpty: set is filled, it should be false") - } -} - -func TestSet_IsEqual(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3") - u := NewThreadSafeSet("1", "2", "3") - - ok := s.IsEqual(u) - if !ok { - t.Error("IsEqual: set s and t are equal. However it returns false") - } - - // same size, different content - a := NewThreadSafeSet("1", "2", "3") - b := NewThreadSafeSet("4", "5", "6") - - ok = a.IsEqual(b) - if ok { - t.Error("IsEqual: set a and b are now equal (1). However it returns true") - } - - // different size, similar content - a = NewThreadSafeSet("1", "2", "3") - b = NewThreadSafeSet("1", "2", "3", "4") - - ok = a.IsEqual(b) - if ok { - t.Error("IsEqual: set s and t are now equal (2). However it returns true") - } - -} - -func TestSet_IsSubset(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3", "4") - u := NewThreadSafeSet("1", "2", "3") - - ok := s.IsSubset(u) - if !ok { - t.Error("IsSubset: u is a subset of s. However it returns false") - } - - ok = u.IsSubset(s) - if ok { - t.Error("IsSubset: s is not a subset of u. However it returns true") - } - -} - -func TestSet_IsSuperset(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3", "4") - u := NewThreadSafeSet("1", "2", "3") - - ok := u.IsSuperset(s) - if !ok { - t.Error("IsSuperset: s is a superset of u. However it returns false") - } - - ok = s.IsSuperset(u) - if ok { - t.Error("IsSuperset: u is not a superset of u. However it returns true") - } - -} - -func TestSet_List(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3", "4") - - // this returns a slice of interface{} - if len(s.List()) != 4 { - t.Error("List: slice size should be four.") - } - - for _, item := range s.List() { - r := reflect.TypeOf(item) - if r.Kind().String() != "string" { - t.Error("List: slice item should be a string") - } - } -} - -func TestSet_Copy(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3", "4") - r := s.Copy() - - if !s.IsEqual(r) { - t.Error("Copy: set s and r are not equal") - } -} - -func TestSet_Merge(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3") - r := NewThreadSafeSet("3", "4", "5") - s.Merge(r) - - if s.Size() != 5 { - t.Error("Merge: the set doesn't have all items in it.") - } - - if !s.Has("1", "2", "3", "4", "5") { - t.Error("Merge: merged items are not availabile in the set.") - } -} - -func TestSet_Separate(t *testing.T) { - s := NewThreadSafeSet("1", "2", "3") - r := NewThreadSafeSet("3", "5") - s.Separate(r) - - if s.Size() != 2 { - t.Error("Separate: the set doesn't have all items in it.") - } - - if !s.Has("1", "2") { - t.Error("Separate: items after separation are not availabile in the set.") - } -} - -func TestSet_RaceAdd(t *testing.T) { - // Create two sets. Add concurrently items to each of them. Remove from the - // other one. - // "go test -race" should detect this if the library is not thread-safe. - s := NewThreadSafeSet() - u := NewThreadSafeSet() - - go func() { - for i := 0; i < 1000; i++ { - item := "item" + strconv.Itoa(i) - go func(i int) { - s.Add(item) - u.Add(item) - }(i) - } - }() - - for i := 0; i < 1000; i++ { - item := "item" + strconv.Itoa(i) - go func(i int) { - s.Add(item) - u.Add(item) - }(i) - } -} diff --git a/datastructures/set/set.go b/datastructures/set/set.go index fb5a9b0..5c903f0 100644 --- a/datastructures/set/set.go +++ b/datastructures/set/set.go @@ -1,24 +1,85 @@ package set -// Set interface shared between Thread-Safe and Thread-Unsafe implementations -type Set interface { - Add(items ...interface{}) - Remove(items ...interface{}) - Pop() interface{} - Has(items ...interface{}) bool - Size() int - Clear() - IsEmpty() bool - IsEqual(s Set) bool - IsSubset(s Set) bool - IsSuperset(s Set) bool - Each(func(interface{}) bool) - List() []interface{} - Copy() Set - Merge(s Set) - Separate(t Set) -} - -type set struct { - m map[interface{}]struct{} +type Set[T comparable] struct { + data map[T]struct{} +} + +func New[T comparable](capacity int) Set[T] { + return Set[T]{ + data: make(map[T]struct{}, capacity), + } +} + +func Define[T comparable](items ...T) Set[T] { + tr := New[T](len(items)) + tr.AddFromSlice(items) + return tr +} + +func (s *Set[T]) Add(item ...T) { + for idx := range item { + s.data[item[idx]] = struct{}{} + } +} + +func (s *Set[T]) Contains(item T) bool { + _, exists := s.data[item] + return exists +} + +func (s *Set[T]) Remove(item T) { + delete(s.data, item) +} + +func (s *Set[T]) Len() int { + return len(s.data) +} + +func (s *Set[T]) Range(callback func(item T) bool) { + for item := range s.data { + if !callback(item) { + return + } + } +} + +func (s *Set[T]) Clone() Set[T] { + tr := New[T](s.Len()) + tr.AddFromAnother(*s) + return tr +} + +func (s *Set[T]) Union(another Set[T]) { + s.AddFromAnother(another) +} + +func (s *Set[T]) Intersect(another Set[T]) { + s.Range(func(item T) bool { + if !another.Contains(item) { + s.Remove(item) + } + return true + }) +} + +func (s *Set[T]) AddFromAnother(another Set[T]) { + another.Range(func(item T) bool { + s.Add(item) + return true + }) +} + +func (s *Set[T]) AddFromSlice(items []T) { + for _, item := range items { + s.Add(item) + } +} + +func (s *Set[T]) ToSlice() []T { + tr := make([]T, 0, s.Len()) + s.Range(func(item T) bool { + tr = append(tr, item) + return true + }) + return tr } diff --git a/datastructures/set/set_test.go b/datastructures/set/set_test.go new file mode 100644 index 0000000..86ee0b4 --- /dev/null +++ b/datastructures/set/set_test.go @@ -0,0 +1,41 @@ +package set + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSet(t *testing.T) { + s := New[int](10) + assert.Equal(t, int(0), s.Len()) + s.Add(1, 2, 3, 4, 5, 6, 7, 8, 9) + assert.Equal(t, int(9), s.Len()) + for i := 1; i <= 9; i++ { + assert.True(t, s.Contains(i)) + } + assert.False(t, s.Contains(0)) + assert.False(t, s.Contains(10)) + + s.Add(0) + assert.True(t, s.Contains(0)) + s.Remove(0) + assert.False(t, s.Contains(0)) + + assert.Equal(t, s, s.Clone()) + + s.Intersect(Define(5, 245)) + assert.Equal(t, Define(int(5)), s) + + s.Union(Define(10, 20, 30)) + assert.Equal(t, Define(5, 10, 20, 30), s) + + assert.Equal(t, Define(1, 2, 3), Define(3, 2, 1)) + assert.Equal(t, Define(1), Define(1, 1, 1)) + + asSlice := s.ToSlice() + assert.ElementsMatch(t, []int{5, 10, 20, 30}, asSlice) + n := New[int](4) + n.AddFromSlice(asSlice) + assert.Equal(t, Define(5, 10, 20, 30), n) +} diff --git a/go.mod b/go.mod index 9af38fe..5585cea 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,17 @@ module github.com/splitio/go-toolkit/v5 -go 1.18 +go 1.21 -require github.com/redis/go-redis/v9 v9.0.4 +require ( + github.com/redis/go-redis/v9 v9.0.4 + github.com/stretchr/testify v1.9.0 +) require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 367a7ef..7d879af 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,22 @@ github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= +github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= +github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.0.4 h1:FC82T+CHJ/Q/PdyLW++GeCO+Ol59Y4T7R4jbgjvktgc= github.com/redis/go-redis/v9 v9.0.4/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hasher/hasher.go b/hashing/hasher.go similarity index 97% rename from hasher/hasher.go rename to hashing/hasher.go index 148c320..58f84a0 100644 --- a/hasher/hasher.go +++ b/hashing/hasher.go @@ -1,4 +1,4 @@ -package hasher +package hashing // Hasher interface type Hasher interface { diff --git a/hasher/hasher_test.go b/hashing/hasher_test.go similarity index 99% rename from hasher/hasher_test.go rename to hashing/hasher_test.go index 2a4ce20..b496589 100644 --- a/hasher/hasher_test.go +++ b/hashing/hasher_test.go @@ -1,4 +1,4 @@ -package hasher +package hashing import ( "bufio" diff --git a/provisional/hashing/murmur128.go b/hashing/murmur128.go similarity index 100% rename from provisional/hashing/murmur128.go rename to hashing/murmur128.go diff --git a/provisional/hashing/murmur128_test.go b/hashing/murmur128_test.go similarity index 88% rename from provisional/hashing/murmur128_test.go rename to hashing/murmur128_test.go index c56e2fb..0aacb79 100644 --- a/provisional/hashing/murmur128_test.go +++ b/hashing/murmur128_test.go @@ -1,14 +1,14 @@ package hashing import ( - "io/ioutil" + "os" "strconv" "strings" "testing" ) func TestMurmur128(t *testing.T) { - raw, err := ioutil.ReadFile("../../testfiles/murmur3_64_uuids.csv") + raw, err := os.ReadFile("../testdata/murmur3_64_uuids.csv") if err != nil { t.Error("error reading murmur128 test cases files: ", err.Error()) } diff --git a/hasher/murmur32.go b/hashing/murmur32.go similarity index 99% rename from hasher/murmur32.go rename to hashing/murmur32.go index f611908..2d38721 100644 --- a/hasher/murmur32.go +++ b/hashing/murmur32.go @@ -1,4 +1,4 @@ -package hasher +package hashing // Implementation borrowed from https://github.com/spaolacci/murmur3, // distributed under BSD-3 license. diff --git a/hasher/util.go b/hashing/util.go similarity index 95% rename from hasher/util.go rename to hashing/util.go index f8f50a8..e641877 100644 --- a/hasher/util.go +++ b/hashing/util.go @@ -1,4 +1,4 @@ -package hasher +package hashing import ( "encoding/base64" diff --git a/hasher/util_test.go b/hashing/util_test.go similarity index 96% rename from hasher/util_test.go rename to hashing/util_test.go index 97733d0..19f5810 100644 --- a/hasher/util_test.go +++ b/hashing/util_test.go @@ -1,4 +1,4 @@ -package hasher +package hashing import "testing" diff --git a/logging/interface.go b/logging/interface.go index 2b42e06..75c19fa 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -9,6 +9,11 @@ type LoggerInterface interface { Info(msg ...interface{}) Debug(msg ...interface{}) Verbose(msg ...interface{}) + Errorf(fmt string, msg ...interface{}) + Warningf(fmt string, msg ...interface{}) + Infof(fmt string, msg ...interface{}) + Debugf(fmt string, msg ...interface{}) + Verbosef(fmt string, msg ...interface{}) } // ParamsFn is a function that returns a slice of interface{} diff --git a/logging/levels.go b/logging/levels.go index 36e1763..2b9ff85 100644 --- a/logging/levels.go +++ b/logging/levels.go @@ -76,6 +76,45 @@ func (l *LevelFilteredLoggerWrapper) Verbose(is ...interface{}) { } } +// Debugf implements LoggerInterface. +func (l *LevelFilteredLoggerWrapper) Debugf(fmt string, args ...interface{}) { + if l.level >= LevelDebug { + l.delegate.Debugf(fmt, args...) + } +} + +// Errorf implements LoggerInterface. +func (l *LevelFilteredLoggerWrapper) Errorf(fmt string, args ...interface{}) { + if l.level >= LevelError { + l.delegate.Errorf(fmt, args...) + } +} + +// Infof implements LoggerInterface. +func (l *LevelFilteredLoggerWrapper) Infof(fmt string, args ...interface{}) { + if l.level >= LevelInfo { + l.delegate.Infof(fmt, args...) + } +} + +// Verbosef implements LoggerInterface. +func (l *LevelFilteredLoggerWrapper) Verbosef(fmt string, args ...interface{}) { + if l.level >= LevelVerbose { + l.delegate.Verbosef(fmt, args...) + } +} + +// Warningf implements LoggerInterface. +func (l *LevelFilteredLoggerWrapper) Warningf(fmt string, args ...interface{}) { + if l.level >= LevelWarning { + l.delegate.Warningf(fmt, args...) + } +} + + + +var _ LoggerInterface = (*LevelFilteredLoggerWrapper)(nil) + var levels map[string]int = map[string]int{ "ERROR": LevelError, "WARNING": LevelWarning, diff --git a/logging/levels_test.go b/logging/levels_test.go index c40e36d..21ef84d 100644 --- a/logging/levels_test.go +++ b/logging/levels_test.go @@ -2,40 +2,15 @@ package logging import ( "testing" -) - -type mockedLogger struct { - msgs map[string]bool -} - -func (l *mockedLogger) reset() { - l.msgs = make(map[string]bool) -} - -func (l *mockedLogger) Error(msg ...interface{}) { - l.msgs["Error"] = true -} -func (l *mockedLogger) Warning(msg ...interface{}) { - l.msgs["Warning"] = true -} - -func (l *mockedLogger) Info(msg ...interface{}) { - l.msgs["Info"] = true -} - -func (l *mockedLogger) Debug(msg ...interface{}) { - l.msgs["Debug"] = true -} - -func (l *mockedLogger) Verbose(msg ...interface{}) { - l.msgs["Verbose"] = true -} + "github.com/splitio/go-toolkit/v5/logging/mocks" +) func TestErrorLevel(t *testing.T) { - delegate := &mockedLogger{} - delegate.reset() + delegate := &mocks.MockLogger{} + delegate.On("Error", "text").Once() + delegate.On("Errorf", "formatted %d", int(3)).Once() logger := LevelFilteredLoggerWrapper{ delegate: delegate, @@ -43,31 +18,25 @@ func TestErrorLevel(t *testing.T) { } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldBeCalled := []string{"Error"} - shouldNotBeCalled := []string{"Warning", "Info", "Debug", "Verbose"} - - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } - - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } + delegate.AssertExpectations(t) } func TestWarningLevel(t *testing.T) { - - delegate := &mockedLogger{} - delegate.reset() + delegate := &mocks.MockLogger{} + delegate.On("Error", "text").Once() + delegate.On("Errorf", "formatted %d", int(3)).Once() + delegate.On("Warning", "text").Once() + delegate.On("Warningf", "formatted %d", int(3)).Once() logger := LevelFilteredLoggerWrapper{ delegate: delegate, @@ -75,31 +44,27 @@ func TestWarningLevel(t *testing.T) { } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldBeCalled := []string{"Error", "Warning"} - shouldNotBeCalled := []string{"Info", "Debug", "Verbose"} - - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } - - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } + delegate.AssertExpectations(t) } func TestInfoLevel(t *testing.T) { - - delegate := &mockedLogger{} - delegate.reset() + delegate := &mocks.MockLogger{} + delegate.On("Error", "text").Once() + delegate.On("Errorf", "formatted %d", int(3)).Once() + delegate.On("Warning", "text").Once() + delegate.On("Warningf", "formatted %d", int(3)).Once() + delegate.On("Info", "text").Once() + delegate.On("Infof", "formatted %d", int(3)).Once() logger := LevelFilteredLoggerWrapper{ delegate: delegate, @@ -107,31 +72,29 @@ func TestInfoLevel(t *testing.T) { } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldBeCalled := []string{"Error", "Warning", "Info"} - shouldNotBeCalled := []string{"Debug", "Verbose"} - - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } - - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } + delegate.AssertExpectations(t) } func TestDebugLevel(t *testing.T) { - - delegate := &mockedLogger{} - delegate.reset() + delegate := &mocks.MockLogger{} + delegate.On("Error", "text").Once() + delegate.On("Errorf", "formatted %d", int(3)).Once() + delegate.On("Warning", "text").Once() + delegate.On("Warningf", "formatted %d", int(3)).Once() + delegate.On("Info", "text").Once() + delegate.On("Infof", "formatted %d", int(3)).Once() + delegate.On("Debug", "text").Once() + delegate.On("Debugf", "formatted %d", int(3)).Once() logger := LevelFilteredLoggerWrapper{ delegate: delegate, @@ -139,31 +102,31 @@ func TestDebugLevel(t *testing.T) { } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldBeCalled := []string{"Error", "Warning", "Info", "Debug"} - shouldNotBeCalled := []string{"Verbose"} - - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } - - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } + delegate.AssertExpectations(t) } func TestVerboseLevel(t *testing.T) { - - delegate := &mockedLogger{} - delegate.reset() + delegate := &mocks.MockLogger{} + delegate.On("Error", "text").Once() + delegate.On("Errorf", "formatted %d", int(3)).Once() + delegate.On("Warning", "text").Once() + delegate.On("Warningf", "formatted %d", int(3)).Once() + delegate.On("Info", "text").Once() + delegate.On("Infof", "formatted %d", int(3)).Once() + delegate.On("Debug", "text").Once() + delegate.On("Debugf", "formatted %d", int(3)).Once() + delegate.On("Verbose", "text").Once() + delegate.On("Verbosef", "formatted %d", int(3)).Once() logger := LevelFilteredLoggerWrapper{ delegate: delegate, @@ -171,31 +134,31 @@ func TestVerboseLevel(t *testing.T) { } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldBeCalled := []string{"Error", "Warning", "Info", "Debug", "Verbose"} - shouldNotBeCalled := []string{} - - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } - - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } + delegate.AssertExpectations(t) } func TestAllLevel(t *testing.T) { - - delegate := &mockedLogger{} - delegate.reset() + delegate := &mocks.MockLogger{} + delegate.On("Error", "text").Once() + delegate.On("Errorf", "formatted %d", int(3)).Once() + delegate.On("Warning", "text").Once() + delegate.On("Warningf", "formatted %d", int(3)).Once() + delegate.On("Info", "text").Once() + delegate.On("Infof", "formatted %d", int(3)).Once() + delegate.On("Debug", "text").Once() + delegate.On("Debugf", "formatted %d", int(3)).Once() + delegate.On("Verbose", "text").Once() + delegate.On("Verbosef", "formatted %d", int(3)).Once() logger := LevelFilteredLoggerWrapper{ delegate: delegate, @@ -203,58 +166,58 @@ func TestAllLevel(t *testing.T) { } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldBeCalled := []string{"Error", "Warning", "Info", "Debug", "Verbose"} - shouldNotBeCalled := []string{} + delegate.AssertExpectations(t) - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } - - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } } func TestNoneLevel(t *testing.T) { - - delegate := &mockedLogger{} - delegate.reset() - + delegate := &mocks.MockLogger{} logger := LevelFilteredLoggerWrapper{ delegate: delegate, level: LevelNone, } logger.Error("text") + logger.Errorf("formatted %d", int(3)) logger.Warning("text") + logger.Warningf("formatted %d", int(3)) logger.Info("text") + logger.Infof("formatted %d", int(3)) logger.Debug("text") + logger.Debugf("formatted %d", int(3)) logger.Verbose("text") + logger.Verbosef("formatted %d", int(3)) - shouldNotBeCalled := []string{"Error", "Warning", "Info", "Debug", "Verbose"} - shouldBeCalled := []string{} + delegate.AssertExpectations(t) +} - for _, level := range shouldBeCalled { - if !delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should have been forwarded", level) - } - } +// --------------------------------- - for _, level := range shouldNotBeCalled { - if delegate.msgs[level] { - t.Errorf("Call to log level function \"%s\" should NOT have been forwarded", level) - } - } -} +type mockedLogger struct{ msgs map[string]bool } + +func (*mockedLogger) Debugf(fmt string, msg ...interface{}) { panic("unimplemented") } +func (*mockedLogger) Errorf(fmt string, msg ...interface{}) { panic("unimplemented") } +func (*mockedLogger) Infof(fmt string, msg ...interface{}) { panic("unimplemented") } +func (*mockedLogger) Verbosef(fmt string, msg ...interface{}) { panic("unimplemented") } +func (*mockedLogger) Warningf(fmt string, msg ...interface{}) { panic("unimplemented") } +func (l *mockedLogger) reset() { l.msgs = make(map[string]bool) } +func (l *mockedLogger) Error(msg ...interface{}) { l.msgs["Error"] = true } +func (l *mockedLogger) Warning(msg ...interface{}) { l.msgs["Warning"] = true } +func (l *mockedLogger) Info(msg ...interface{}) { l.msgs["Info"] = true } +func (l *mockedLogger) Debug(msg ...interface{}) { l.msgs["Debug"] = true } +func (l *mockedLogger) Verbose(msg ...interface{}) { l.msgs["Verbose"] = true } + +var _ LoggerInterface = (*mockedLogger)(nil) func writelog(logger *ExtendedLevelFilteredLoggerWrapper) { logger.ErrorFn("hello %s", func() []interface{} { return []interface{}{"world"} }) diff --git a/logging/logging.go b/logging/logging.go index 81106f4..e125028 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -65,6 +65,31 @@ func (l *Logger) Error(msg ...interface{}) { l.errorLogger.Output(l.framesToSkip, fmt.Sprintln(msg...)) } +// Verbose logs a message with Debug level +func (l *Logger) Verbosef(f string, args ...interface{}) { + l.verboseLogger.Output(l.framesToSkip, fmt.Sprintf(f, args...)) +} + +// Debug logs a message with Debug level +func (l *Logger) Debugf(f string, args ...interface{}) { + l.debugLogger.Output(l.framesToSkip, fmt.Sprintf(f, args...)) +} + +// Info logs a message with Info level +func (l *Logger) Infof(f string, args ...interface{}) { + l.infoLogger.Output(l.framesToSkip, fmt.Sprintf(f, args...)) +} + +// Warning logs a message with Warning level +func (l *Logger) Warningf(f string, args ...interface{}) { + l.warningLogger.Output(l.framesToSkip, fmt.Sprintf(f, args...)) +} + +// Error logs a message with Error level +func (l *Logger) Errorf(f string, args ...interface{}) { + l.errorLogger.Output(l.framesToSkip, fmt.Sprintf(f, args...)) +} + func normalizeOptions(options *LoggerOptions) *LoggerOptions { var toRet *LoggerOptions if options == nil { diff --git a/logging/mocks/mocks.go b/logging/mocks/mocks.go index 80b10e8..d622b41 100644 --- a/logging/mocks/mocks.go +++ b/logging/mocks/mocks.go @@ -1,29 +1,59 @@ package mocks +import ( + "github.com/stretchr/testify/mock" +) + type MockLogger struct { - ErrorCall func(msg ...interface{}) - WarningCall func(msg ...interface{}) - InfoCall func(msg ...interface{}) - DebugCall func(msg ...interface{}) - VerboseCall func(msg ...interface{}) + mock.Mock +} + +// Debug implements logging.LoggerInterface. +func (l *MockLogger) Debug(msg ...interface{}) { + l.Called(msg...) } +// Debugf implements logging.LoggerInterface. +func (l *MockLogger) Debugf(fmt string, msg ...interface{}) { + l.Called(append([]interface{}{fmt}, msg...)...) +} + +// Error implements logging.LoggerInterface. func (l *MockLogger) Error(msg ...interface{}) { - l.ErrorCall(msg...) + l.Called(msg...) } -func (l *MockLogger) Warning(msg ...interface{}) { - l.WarningCall(msg...) +// Errorf implements logging.LoggerInterface. +func (l *MockLogger) Errorf(fmt string, msg ...interface{}) { + l.Called(append([]interface{}{fmt}, msg...)...) } +// Info implements logging.LoggerInterface. func (l *MockLogger) Info(msg ...interface{}) { - l.InfoCall(msg...) + l.Called(msg...) } -func (l *MockLogger) Debug(msg ...interface{}) { - l.DebugCall(msg...) +// Infof implements logging.LoggerInterface. +func (l *MockLogger) Infof(fmt string, msg ...interface{}) { + l.Called(append([]interface{}{fmt}, msg...)...) } +// Verbose implements logging.LoggerInterface. func (l *MockLogger) Verbose(msg ...interface{}) { - l.VerboseCall(msg...) + l.Called(msg...) +} + +// Verbosef implements logging.LoggerInterface. +func (l *MockLogger) Verbosef(fmt string, msg ...interface{}) { + l.Called(append([]interface{}{fmt}, msg...)...) +} + +// Warning implements logging.LoggerInterface. +func (l *MockLogger) Warning(msg ...interface{}) { + l.Called(msg...) +} + +// Warningf implements logging.LoggerInterface. +func (l *MockLogger) Warningf(fmt string, msg ...interface{}) { + l.Called(append([]interface{}{fmt}, msg...)...) } diff --git a/provisional/int64cache/cache.go b/provisional/int64cache/cache.go deleted file mode 100644 index c65a04d..0000000 --- a/provisional/int64cache/cache.go +++ /dev/null @@ -1,91 +0,0 @@ -package int64cache - -import ( - "container/list" - "fmt" - "sync" -) - -// Int64Cache is an in-memory TTL & LRU cache -type Int64Cache interface { - Get(key int64) (int64, error) - Set(key int64, value int64) error -} - -// Impl implements the LocalCache interface -type Impl struct { - maxLen int - items map[int64]*list.Element - lru *list.List - mutex sync.Mutex -} - -type entry struct { - key int64 - value int64 -} - -// Get retrieves an item if exist, nil + an error otherwise -func (c *Impl) Get(key int64) (int64, error) { - c.mutex.Lock() - defer c.mutex.Unlock() - node, ok := c.items[key] - if !ok { - return 0, &Miss{} - } - - entry, ok := node.Value.(entry) - if !ok { - return 0, fmt.Errorf("Invalid data in cache for key %d", key) - } - - c.lru.MoveToFront(node) - return entry.value, nil -} - -// Set adds a new item. Since the cache being full results in removing the LRU element, this method never fails. -func (c *Impl) Set(key int64, value int64) error { - c.mutex.Lock() - defer c.mutex.Unlock() - if node, ok := c.items[key]; ok { - c.lru.MoveToFront(node) - node.Value = entry{key: key, value: value} - } else { - // Drop the LRU item on the list before adding a new one. - if c.lru.Len() == c.maxLen { - entry, ok := c.lru.Back().Value.(entry) - if !ok { - return fmt.Errorf("Invalid data in list for key %d", key) - } - key := entry.key - delete(c.items, key) - c.lru.Remove(c.lru.Back()) - } - - ptr := c.lru.PushFront(entry{key: key, value: value}) - c.items[key] = ptr - } - return nil -} - -// NewInt64Cache returns a new LocalCache instance of the specified size and TTL -func NewInt64Cache(maxSize int) (*Impl, error) { - if maxSize <= 0 { - return nil, fmt.Errorf("Cache size should be > 0. Is: %d", maxSize) - } - - return &Impl{ - maxLen: maxSize, - lru: new(list.List), - items: make(map[int64]*list.Element, maxSize), - }, nil -} - -// Miss is a special error indicating the key was not found in the cache -type Miss struct { - Key int64 -} - -func (m *Miss) Error() string { - return fmt.Sprintf("key %d not found in cache", m.Key) -} diff --git a/provisional/int64cache/cache_test.go b/provisional/int64cache/cache_test.go deleted file mode 100644 index 2633bca..0000000 --- a/provisional/int64cache/cache_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package int64cache - -import ( - "testing" -) - -func TestInt64Cache(t *testing.T) { - c, err := NewInt64Cache(5) - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } - - for i := int64(1); i <= 5; i++ { - err := c.Set(i, i) - if err != nil { - t.Errorf("Setting value '%d', should not have raised an error. Got: %s", i, err) - } - } - - for i := int64(1); i <= 5; i++ { - val, err := c.Get(i) - if err != nil { - t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) - } - } - - c.Set(6, 6) - - // Oldest item (1) should have been removed - val, err := c.Get(1) - if err == nil { - t.Errorf("Getting value 'someKey1', should not have raised an error. Got: %s", err) - } - - _, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of type Miss. Is %T", err) - } - - if val != 0 { - t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) - } - - // 2-6 should be available - for i := int64(2); i <= 6; i++ { - val, err := c.Get(i) - if err != nil { - t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) - } - } - - if len(c.items) != 5 { - t.Error("Items size should be 5. is: ", len(c.items)) - } -} diff --git a/redis/wrapper_test.go b/redis/wrapper_test.go index a66ea00..e5fb32b 100644 --- a/redis/wrapper_test.go +++ b/redis/wrapper_test.go @@ -6,7 +6,7 @@ import ( "time" "github.com/redis/go-redis/v9" - "github.com/splitio/go-toolkit/v5/testhelpers" + "github.com/stretchr/testify/assert" ) func TestRedisWrapperKeysAndScan(t *testing.T) { @@ -93,7 +93,7 @@ func TestRedisWrapperPipeline(t *testing.T) { } items, _ := result[0].Multi() - testhelpers.AssertStringSliceEquals(t, items, []string{"e1", "e2", "e3"}, "result of lrange should be e1,e2,e3") + assert.Equal(t, []string{"e1", "e2", "e3"}, items) if l := result[1].Int(); l != 3 { t.Error("length should be 3. is: ", l) } diff --git a/json-struct-validator/validator.go b/struct/jsonvalidator/validator.go similarity index 94% rename from json-struct-validator/validator.go rename to struct/jsonvalidator/validator.go index 9b30481..8aa70ea 100644 --- a/json-struct-validator/validator.go +++ b/struct/jsonvalidator/validator.go @@ -1,4 +1,4 @@ -package validator +package jsonvalidator import ( "encoding/json" @@ -64,9 +64,9 @@ func getFieldsForMap(s map[string]interface{}) []string { return getFieldsForMapRecursive("", s) } -func validateParameters(userConf []string, p *set.ThreadUnsafeSet) error { +func validateParameters(userConf []string, p set.Set[string]) error { for _, field := range userConf { - if !p.Has(field) { + if !p.Contains(field) { return errors.New(field) } } @@ -89,7 +89,7 @@ func ValidateConfiguration(p interface{}, s []byte) error { structToInspect = reflect.Indirect(reflect.ValueOf(p)).Interface() } primaryFieldList := getFieldsForStruct(structToInspect) - primarySet := set.NewSet() + primarySet := set.New[string](100) for _, c := range primaryFieldList { primarySet.Add(c) } diff --git a/json-struct-validator/validator_test.go b/struct/jsonvalidator/validator_test.go similarity index 99% rename from json-struct-validator/validator_test.go rename to struct/jsonvalidator/validator_test.go index b63e006..c79eb74 100644 --- a/json-struct-validator/validator_test.go +++ b/struct/jsonvalidator/validator_test.go @@ -1,4 +1,4 @@ -package validator +package jsonvalidator import ( "testing" diff --git a/testfiles/murmur3_64_uuids.csv b/testdata/murmur3_64_uuids.csv similarity index 100% rename from testfiles/murmur3_64_uuids.csv rename to testdata/murmur3_64_uuids.csv diff --git a/testhelpers/helpers.go b/testhelpers/helpers.go deleted file mode 100644 index d91a922..0000000 --- a/testhelpers/helpers.go +++ /dev/null @@ -1,48 +0,0 @@ -package testhelpers - -import ( - "testing" - - "github.com/splitio/go-toolkit/v5/datastructures/set" -) - -// AssertStringSliceEquals fails is two string slices are not identical -func AssertStringSliceEquals(t *testing.T, actual []string, expected []string, message string) { - t.Helper() - if len(actual) != len(expected) { - t.Errorf(message) - t.Errorf("Slices have different sizes. Actual: %d, expected: %d", len(actual), len(expected)) - t.Errorf("Actual: %v || Expected: %v", actual, expected) - return - } - - idx := 0 - for idx < len(actual) && actual[idx] == expected[idx] { - idx++ - } - - if idx != len(actual) { - t.Errorf(message) - t.Errorf("Slices have different elements") - t.Errorf("Actual: %v || Expected: %v", actual, expected) - } -} - -func AssertStringSliceEqualsNoOrder(t *testing.T, actual []string, expected []string, message string) { - t.Helper() - asInterfaces1 := make([]interface{}, 0, len(actual)) - for _, s := range actual { - asInterfaces1 = append(asInterfaces1, s) - } - asInterfaces2 := make([]interface{}, 0, len(expected)) - for _, s := range expected { - asInterfaces2 = append(asInterfaces2, s) - } - set1 := set.NewSet(asInterfaces1...) - set2 := set.NewSet(asInterfaces2...) - if !set1.IsEqual(set2) { - t.Error("slices contain different elements despite order: ", message) - t.Error("actual: ", actual) - t.Error("expected: ", expected) - } -} From c250cf013c9c523e8e5d095de509e002227c96f4 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Tue, 26 Mar 2024 11:43:00 -0300 Subject: [PATCH 2/8] switch to v6 --- Makefile | 11 +++++++++++ asynctask/asynctasks.go | 4 ++-- asynctask/asynctasks_test.go | 2 +- backoff/mocks/mocks.go | 2 +- datastructures/cache/multilevel.go | 2 +- datastructures/cache/multilevel_test.go | 2 +- go.mod | 2 +- logging/levels_test.go | 2 +- redis/helpers/helpers.go | 2 +- redis/helpers/helpers_test.go | 4 ++-- redis/mocks/mocks.go | 2 +- sse/sse.go | 4 ++-- sse/sse_test.go | 2 +- struct/jsonvalidator/validator.go | 2 +- workerpool/workerpool.go | 4 ++-- workerpool/workerpool_test.go | 2 +- 16 files changed, 30 insertions(+), 19 deletions(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2b35782 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +GO ?= go + +.PHONY: test test-norace + +test: + $(GO) test ./... -count=1 -race + +test-norace: + $(GO) test ./... -count=1 + + diff --git a/asynctask/asynctasks.go b/asynctask/asynctasks.go index 06a04ac..226a898 100644 --- a/asynctask/asynctasks.go +++ b/asynctask/asynctasks.go @@ -4,8 +4,8 @@ import ( "fmt" "time" - "github.com/splitio/go-toolkit/v5/logging" - "github.com/splitio/go-toolkit/v5/struct/traits/lifecycle" + "github.com/splitio/go-toolkit/v6/logging" + "github.com/splitio/go-toolkit/v6/struct/traits/lifecycle" ) // AsyncTask is a struct that wraps tasks that should run periodically and can be remotely stopped & started, diff --git a/asynctask/asynctasks_test.go b/asynctask/asynctasks_test.go index 5e60665..b1012fb 100644 --- a/asynctask/asynctasks_test.go +++ b/asynctask/asynctasks_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/splitio/go-toolkit/v5/logging" + "github.com/splitio/go-toolkit/v6/logging" ) func TestAsyncTaskNormalOperation(t *testing.T) { diff --git a/backoff/mocks/mocks.go b/backoff/mocks/mocks.go index dca07d8..555fc80 100644 --- a/backoff/mocks/mocks.go +++ b/backoff/mocks/mocks.go @@ -1,7 +1,7 @@ package mocks import ( - "github.com/splitio/go-toolkit/v5/backoff" + "github.com/splitio/go-toolkit/v6/backoff" "github.com/stretchr/testify/mock" "time" ) diff --git a/datastructures/cache/multilevel.go b/datastructures/cache/multilevel.go index 917eb8d..0c2254e 100644 --- a/datastructures/cache/multilevel.go +++ b/datastructures/cache/multilevel.go @@ -3,7 +3,7 @@ package cache import ( "context" - "github.com/splitio/go-toolkit/v5/logging" + "github.com/splitio/go-toolkit/v6/logging" ) // MLCLayer is the interface that should be implemented for all caching structs to be used with this piece of code. diff --git a/datastructures/cache/multilevel_test.go b/datastructures/cache/multilevel_test.go index c3549d2..22d2907 100644 --- a/datastructures/cache/multilevel_test.go +++ b/datastructures/cache/multilevel_test.go @@ -6,7 +6,7 @@ import ( "fmt" "testing" - "github.com/splitio/go-toolkit/v5/logging" + "github.com/splitio/go-toolkit/v6/logging" ) type LayerMock struct { diff --git a/go.mod b/go.mod index 5585cea..11a1f1c 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/splitio/go-toolkit/v5 +module github.com/splitio/go-toolkit/v6 go 1.21 diff --git a/logging/levels_test.go b/logging/levels_test.go index 21ef84d..8fde475 100644 --- a/logging/levels_test.go +++ b/logging/levels_test.go @@ -3,7 +3,7 @@ package logging import ( "testing" - "github.com/splitio/go-toolkit/v5/logging/mocks" + "github.com/splitio/go-toolkit/v6/logging/mocks" ) func TestErrorLevel(t *testing.T) { diff --git a/redis/helpers/helpers.go b/redis/helpers/helpers.go index fb728b0..0807d30 100644 --- a/redis/helpers/helpers.go +++ b/redis/helpers/helpers.go @@ -3,7 +3,7 @@ package helpers import ( "fmt" - "github.com/splitio/go-toolkit/v5/redis" + "github.com/splitio/go-toolkit/v6/redis" ) const ( diff --git a/redis/helpers/helpers_test.go b/redis/helpers/helpers_test.go index dca3cd4..d968aa0 100644 --- a/redis/helpers/helpers_test.go +++ b/redis/helpers/helpers_test.go @@ -4,8 +4,8 @@ import ( "errors" "testing" - "github.com/splitio/go-toolkit/v5/redis" - "github.com/splitio/go-toolkit/v5/redis/mocks" + "github.com/splitio/go-toolkit/v6/redis" + "github.com/splitio/go-toolkit/v6/redis/mocks" ) func TestEnsureConnected(t *testing.T) { diff --git a/redis/mocks/mocks.go b/redis/mocks/mocks.go index 88b180a..fb84349 100644 --- a/redis/mocks/mocks.go +++ b/redis/mocks/mocks.go @@ -3,7 +3,7 @@ package mocks import ( "time" - "github.com/splitio/go-toolkit/v5/redis" + "github.com/splitio/go-toolkit/v6/redis" ) // MockResultOutput mocks struct diff --git a/sse/sse.go b/sse/sse.go index 5c533de..cca3354 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -9,8 +9,8 @@ import ( "sync" "time" - "github.com/splitio/go-toolkit/v5/logging" - "github.com/splitio/go-toolkit/v5/struct/traits/lifecycle" + "github.com/splitio/go-toolkit/v6/logging" + "github.com/splitio/go-toolkit/v6/struct/traits/lifecycle" ) const ( diff --git a/sse/sse_test.go b/sse/sse_test.go index df6f2b3..59e256b 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -9,7 +9,7 @@ import ( "testing" "time" - "github.com/splitio/go-toolkit/v5/logging" + "github.com/splitio/go-toolkit/v6/logging" ) func TestSSEErrorConnecting(t *testing.T) { diff --git a/struct/jsonvalidator/validator.go b/struct/jsonvalidator/validator.go index 8aa70ea..151c5e4 100644 --- a/struct/jsonvalidator/validator.go +++ b/struct/jsonvalidator/validator.go @@ -6,7 +6,7 @@ import ( "reflect" "strings" - "github.com/splitio/go-toolkit/v5/datastructures/set" + "github.com/splitio/go-toolkit/v6/datastructures/set" ) func getFieldsForStructRecursive(prefix string, structType reflect.Type) []string { diff --git a/workerpool/workerpool.go b/workerpool/workerpool.go index b763dae..9018a14 100644 --- a/workerpool/workerpool.go +++ b/workerpool/workerpool.go @@ -5,8 +5,8 @@ import ( "sync" "time" - "github.com/splitio/go-toolkit/v5/logging" - "github.com/splitio/go-toolkit/v5/struct/traits/lifecycle" + "github.com/splitio/go-toolkit/v6/logging" + "github.com/splitio/go-toolkit/v6/struct/traits/lifecycle" ) const ( diff --git a/workerpool/workerpool_test.go b/workerpool/workerpool_test.go index 04acbd2..9606b7f 100644 --- a/workerpool/workerpool_test.go +++ b/workerpool/workerpool_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/splitio/go-toolkit/v5/logging" + "github.com/splitio/go-toolkit/v6/logging" ) var resMutex sync.RWMutex From 3dcce66cf7e59362b8591b6e0e8ddff7e642edf7 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Thu, 28 Mar 2024 09:03:06 -0300 Subject: [PATCH 3/8] modernize mocks & test harness --- CHANGES | 7 + asynctask/asynctasks_test.go | 48 +- backoff/backoff_test.go | 30 +- common/common.go | 18 +- datastructures/boolslice/boolslice_test.go | 151 +-- datastructures/cache/cache_test.go | 142 +-- datastructures/cache/mocks/mocks.go | 22 + datastructures/cache/multilevel_test.go | 221 +--- datastructures/queuecache/cache_test.go | 93 +- datautils/compress.go | 6 +- datautils/compress_test.go | 38 +- datautils/encode_test.go | 27 +- deepcopy/deepcopy.go | 114 -- deepcopy/deepcopy_test.go | 1110 -------------------- hashing/hasher_test.go | 23 +- hashing/murmur128_test.go | 12 +- hashing/util_test.go | 24 +- redis/helpers/helpers_test.go | 68 +- redis/mocks/mocks.go | 389 ++++--- redis/wrapper_test.go | 100 +- sse/event_test.go | 31 +- sse/mocks/mocks.go | 21 +- sse/sse_test.go | 123 +-- struct/jsonvalidator/validator_test.go | 76 +- struct/traits/lifecycle/lifecycle_test.go | 160 +-- sync/atomicbool_test.go | 40 +- workerpool/workerpool_test.go | 29 +- 27 files changed, 616 insertions(+), 2507 deletions(-) create mode 100644 datastructures/cache/mocks/mocks.go delete mode 100644 deepcopy/deepcopy.go delete mode 100644 deepcopy/deepcopy_test.go diff --git a/CHANGES b/CHANGES index 4087f68..283ef27 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,10 @@ +v6.0.0 (TBD): +- updated common helpers to be generic +- updated datastructures to be generic +- cleanup package structre and remove deprecated ones +- updated logger with formatting functionality +- modernized test harness & mocks + 5.4.0 (Jan 10, 2024) - Added `Scan` operation to Redis diff --git a/asynctask/asynctasks_test.go b/asynctask/asynctasks_test.go index b1012fb..77e6b99 100644 --- a/asynctask/asynctasks_test.go +++ b/asynctask/asynctasks_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) func TestAsyncTaskNormalOperation(t *testing.T) { @@ -29,27 +30,14 @@ func TestAsyncTaskNormalOperation(t *testing.T) { task1.Start() time.Sleep(1 * time.Second) - if !task1.IsRunning() { - t.Error("Task should be running") - } - time.Sleep(1 * time.Second) + assert.True(t, task1.IsRunning()) + time.Sleep(1 * time.Second) task1.Stop(true) - if task1.IsRunning() { - t.Error("Task should be stopped") - } - - if !onInit.Load().(bool) { - t.Error("Initialization hook not executed") - } - - if !onExecution.Load().(bool) { - t.Error("Main task function not executed") - } - - if !onStop.Load().(bool) { - t.Error("After execution function not executed") - } + assert.False(t, task1.IsRunning()) + assert.True(t, onInit.Load().(bool)) + assert.True(t, onExecution.Load().(bool)) + assert.True(t, onStop.Load().(bool)) } func TestAsyncTaskPanics(t *testing.T) { @@ -94,15 +82,10 @@ func TestAsyncTaskPanics(t *testing.T) { task3.Start() time.Sleep(time.Second * 2) task3.Stop(true) - if task1.IsRunning() { - t.Error("Task1 is running and should be stopped") - } - if task2.IsRunning() { - t.Error("Task2 is running and should be stopped") - } - if task3.IsRunning() { - t.Error("Task3 is running and should be stopped") - } + + assert.False(t, task1.IsRunning()) + assert.False(t, task2.IsRunning()) + assert.False(t, task3.IsRunning()) } func TestAsyncTaskErrors(t *testing.T) { @@ -138,9 +121,8 @@ func TestAsyncTaskErrors(t *testing.T) { task2.Start() time.Sleep(2 * time.Second) - if res.Load().(int) != 0 { - t.Error("Task should have never executed if there was an error when calling onInit()") - } + + assert.Equal(t, int(0), res.Load().(int)) } func TestAsyncTaskWakeUp(t *testing.T) { @@ -163,7 +145,5 @@ func TestAsyncTaskWakeUp(t *testing.T) { _ = task1.WakeUp() _ = task1.Stop(true) - if atomic.LoadInt32(&res) != 3 { - t.Errorf("Task shuld have executed 4 times. It ran %d times", res) - } + assert.Equal(t, int32(3), atomic.LoadInt32(&res)) } diff --git a/backoff/backoff_test.go b/backoff/backoff_test.go index 8836a3e..58c719f 100644 --- a/backoff/backoff_test.go +++ b/backoff/backoff_test.go @@ -3,32 +3,20 @@ package backoff import ( "testing" "time" + "github.com/stretchr/testify/assert" ) func TestBackoff(t *testing.T) { base := int64(10) maxAllowed := 60 * time.Second backoff := New(base, maxAllowed) - if backoff.base != base { - t.Error("It should be equals to 10") - } - if backoff.maxAllowed != maxAllowed { - t.Error("It should be equals to 60") - } - if backoff.Next() != 1*time.Second { - t.Error("It should be 1 second") - } - if backoff.Next() != 10*time.Second { - t.Error("It should be 10 seconds") - } - if backoff.Next() != 60*time.Second { - t.Error("It should be 60 seconds") - } + assert.Equal(t, base, backoff.base) + assert.Equal(t, maxAllowed, backoff.maxAllowed) + assert.Equal(t, 1*time.Second, backoff.Next()) + assert.Equal(t, 10*time.Second, backoff.Next()) + assert.Equal(t, 60*time.Second, backoff.Next()) + backoff.Reset() - if backoff.current != 0 { - t.Error("It should be zero") - } - if backoff.Next() != 1*time.Second { - t.Error("It should be 1 second") - } + assert.Equal(t, int64(0), backoff.current) + assert.Equal(t, 1*time.Second, backoff.Next()) } diff --git a/common/common.go b/common/common.go index b3bbff8..de86446 100644 --- a/common/common.go +++ b/common/common.go @@ -13,11 +13,11 @@ func Ref[T any](x T) *T { // RefOrNil returns a pointer to the value supplied if it's not the default value, nil otherwise func RefOrNil[T comparable](x T) *T { - var t T - if x == t { - return nil - } - return &x + var t T + if x == t { + return nil + } + return &x } // PointerOf performs a type-assertion to T and returns a pointer if successful, nil otherwise. @@ -93,3 +93,11 @@ func Min[T cmp.Ordered](i1 T, rest ...T) T { } return min } + +func AsInterfaceSlice[T any](in []T) []interface{} { + out := make([]interface{}, len(in)) + for idx := range in { + out[idx] = in[idx] + } + return out +} diff --git a/datastructures/boolslice/boolslice_test.go b/datastructures/boolslice/boolslice_test.go index 7432888..20e3093 100644 --- a/datastructures/boolslice/boolslice_test.go +++ b/datastructures/boolslice/boolslice_test.go @@ -3,18 +3,16 @@ package boolslice import ( "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestBoolSlice(t *testing.T) { _, err := NewBoolSlice(12) - if err == nil { - t.Error("It should return err") - } + assert.NotNil(t, err) b, err := NewBoolSlice(int(math.Pow(2, 15))) - if err != nil { - t.Error("It should not return err", err) - } + assert.Nil(t, err) i1 := 12 i2 := 20 @@ -22,96 +20,57 @@ func TestBoolSlice(t *testing.T) { i4 := 2000 i5 := 8192 - if err := b.Set(int(math.Pow(2, 15)) + 1); err == nil { - t.Error("It should return err") - } - if err := b.Set(i1); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i2); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i3); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i4); err != nil { - t.Error("It should not return err") - } - if err := b.Set(i5); err != nil { - t.Error("It should not return err") - } - - if _, err := b.Get(int(math.Pow(2, 15)) + 1); err == nil { - t.Error("It should return err") - } - if v, _ := b.Get(i1); !v { - t.Error("It should match", i1) - } - if v, _ := b.Get(i2); !v { - t.Error("It should match", i2) - } - if v, _ := b.Get(i3); !v { - t.Error("It should match", i3) - } - if v, _ := b.Get(i4); !v { - t.Error("It should match", i4) - } - if v, _ := b.Get(i5); !v { - t.Error("It should match", i5) - } - if v, _ := b.Get(200); v { - t.Error("It should not match 200") - } - if v, _ := b.Get(5000); v { - t.Error("It should not match 5000") - } - - if len(b.Bytes()) != int(math.Pow(2, 15)/8) { - t.Error("Len should be 4096") - } - - if err := b.Clear(int(math.Pow(2, 15)) + 1); err == nil { - t.Error("It should return err") - } - if err := b.Clear(i1); err != nil { - t.Error("It should not return err") - } - - if v, _ := b.Get(i1); v { - t.Error("It should not match after cleared", i1) - } - - if _, err := Rebuild(1, nil); err.Error() != "size must be a multiple of 8" { - t.Error("It should return err") - } - - if _, err := Rebuild(8, nil); err.Error() != "data cannot be empty" { - t.Error("It should return err") - } + assert.Equal(t, ErrorOutOfBounds, b.Set(int(math.Pow(2, 15)) + 1)) + assert.Nil(t, b.Set(i1)) + assert.Nil(t, b.Set(i2)) + assert.Nil(t, b.Set(i3)) + assert.Nil(t, b.Set(i4)) + assert.Nil(t, b.Set(i5)) + + set, err := b.Get(int(math.Pow(2, 15)) + 1) + assert.False(t, set) + assert.Equal(t, ErrorOutOfBounds, err) + + for _, i := range []int{i1, i2, i3, i4, i5} { + res, err := b.Get(i) + assert.Nil(t, err) + assert.True(t, res) + } + + for _, i := range []int{200, 500} { + res, err := b.Get(i) + assert.Nil(t, err) + assert.False(t, res) + } + + assert.Equal(t, math.Pow(2, 15)/8, float64(len(b.Bytes()))) + assert.Equal(t, ErrorOutOfBounds, b.Clear(int(math.Pow(2, 15)) + 1)) + assert.Nil(t, b.Clear(i1)) + + v, err := b.Get(i1) + assert.Nil(t, err) + assert.False(t, v) + + res, err := Rebuild(1, nil) + assert.Nil(t, res) + assert.NotNil(t, err) + + res, err = Rebuild(8, nil) + assert.Nil(t, res) + assert.NotNil(t, err) rebuilt, err := Rebuild(int(math.Pow(2, 15)), b.Bytes()) - if err != nil { - t.Error("It should not return err") - } - if v, _ := rebuilt.Get(i2); !v { - t.Error("It should match", i2) - } - if v, _ := rebuilt.Get(i3); !v { - t.Error("It should match", i3) - } - if v, _ := rebuilt.Get(i4); !v { - t.Error("It should match", i4) - } - if v, _ := rebuilt.Get(i5); !v { - t.Error("It should match", i5) - } - if v, _ := rebuilt.Get(i1); v { - t.Error("It should not match 12") - } - if v, _ := rebuilt.Get(200); v { - t.Error("It should not match 200") - } - if v, _ := rebuilt.Get(5000); v { - t.Error("It should not match 5000") - } + assert.Nil(t, err) + + for _, i := range []int{i2, i3, i4, i5} { + res, err := rebuilt.Get(i) + assert.Nil(t, err) + assert.True(t, res) + } + + for _, i := range []int{200, 5000} { + res, err := rebuilt.Get(i) + assert.Nil(t, err) + assert.False(t, res) + } } diff --git a/datastructures/cache/cache_test.go b/datastructures/cache/cache_test.go index 9ff36cc..1bf9eda 100644 --- a/datastructures/cache/cache_test.go +++ b/datastructures/cache/cache_test.go @@ -6,110 +6,62 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestSimpleCache(t *testing.T) { cache, err := NewSimpleLRU[string, int](5, 1*time.Second) - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } + assert.Nil(t, err) for i := 1; i <= 5; i++ { err := cache.Set(fmt.Sprintf("someKey%d", i), i) - if err != nil { - t.Errorf("Setting value 'someKey%d', should not have raised an error. Got: %s", i, err) - } + assert.Nil(t, err) } for i := 1; i <= 5; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if err != nil { - t.Errorf("Getting value 'someKey%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key 'someKey%d' should be %d. Is %d", i, i, val) - } + assert.Nil(t, err) + assert.Equal(t, i, val) } cache.Set("someKey6", 6) // Oldest item (1) should have been removed val, err := cache.Get("someKey1") - if err == nil { - t.Errorf("Getting value 'someKey1', should not have raised an error. Got: %s", err) - } - + assert.NotNil(t, err) asMiss, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of type Miss. Is %T", err) - } - - if asMiss.Key != "someKey1" || asMiss.Where != "LOCAL" { - t.Errorf("Incorrect data within the Miss error. Got: %+v", asMiss) - } - - if val != 0 { - t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) - } + assert.True(t, ok) + assert.Equal(t, "someKey1", asMiss.Key) + assert.Equal(t, "LOCAL", asMiss.Where) + assert.Equal(t, 0, val) // 2-6 should be available for i := 2; i <= 6; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if err != nil { - t.Errorf("Getting value 'someKey%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key 'someKey%d' should be %d. Is %d", i, i, val) - } - } - - if len(cache.items) != 5 { - t.Error("Items size should be 5. is: ", len(cache.items)) + assert.Nil(t, err) + assert.Equal(t, i, val) } - if len(cache.ttls) != len(cache.items) { - t.Error("TTLs size should be the same size as items") - } - - if cache.lru.Len() != 5 { - t.Error("LRU size should be 5. is: ", cache.lru.Len()) - } + assert.Equal(t, 5, len(cache.items)) + assert.Equal(t, 5, len(cache.ttls)) + assert.Equal(t, 5, cache.lru.Len()) time.Sleep(2 * time.Second) // Wait for all keys to expire. + for i := 2; i <= 6; i++ { val, err := cache.Get(fmt.Sprintf("someKey%d", i)) - if val != 0 { - t.Errorf("No value should have been returned for expired key 'someKey%d'.", i) - } - - if err == nil { - t.Errorf("Getting value 'someKey%d', should have raised an 'Expired' error. Got nil", i) - continue - } - - asExpiredErr, ok := err.(*Expired) - if !ok { - t.Errorf("Returned error should be of 'Expired' type. Is %T", err) - continue - } - - if asExpiredErr.Key != fmt.Sprintf("someKey%d", i) { - t.Errorf("Key in Expired error should be 'someKey%d'. Is: '%s'", i, asExpiredErr.Key) - } - - if asExpiredErr.Value != i { - t.Errorf("Value in Expired error should be %d. Is %+v", i, asExpiredErr.Value) - } + assert.Equal(t, 0, val) + assert.NotNil(t, err) + asExpired, ok := err.(*Expired) + assert.True(t, ok) + assert.Equal(t, fmt.Sprintf("someKey%d", i), asExpired.Key) + assert.Equal(t, i, asExpired.Value) ttl, ok := cache.ttls[fmt.Sprintf("someKey%d", i)] - if !ok { - t.Errorf("A ttl entry should exist for key 'someKey%d'", i) - continue - } + assert.True(t, ok) + assert.Equal(t, asExpired.When, ttl.Add(cache.ttl)) - if asExpiredErr.When != ttl.Add(cache.ttl) { - t.Errorf("Key 'someKey%d' should have expired at %+v. It did at %+v", i, ttl.Add(cache.ttl), asExpiredErr.When) - } } } @@ -141,59 +93,35 @@ func TestSimpleCacheHighConcurrency(t *testing.T) { wg.Wait() } - func TestInt64Cache(t *testing.T) { c, err := NewSimpleLRU[int64, int64](5, NoTTL) - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } + assert.Nil(t, err) for i := int64(1); i <= 5; i++ { - err := c.Set(i, i) - if err != nil { - t.Errorf("Setting value '%d', should not have raised an error. Got: %s", i, err) - } + assert.Nil(t, c.Set(i, i)) } for i := int64(1); i <= 5; i++ { val, err := c.Get(i) - if err != nil { - t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) - } + assert.Nil(t, err) + assert.Equal(t, i, val) } c.Set(6, 6) // Oldest item (1) should have been removed val, err := c.Get(1) - if err == nil { - t.Errorf("Getting value 'someKey1', should not have raised an error. Got: %s", err) - } - + assert.NotNil(t, err) _, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of type Miss. Is %T", err) - } - - if val != 0 { - t.Errorf("Value for key 'someKey1' should be nil. Is %d", val) - } + assert.True(t, ok) + assert.Equal(t, int64(0), val) // 2-6 should be available for i := int64(2); i <= 6; i++ { val, err := c.Get(i) - if err != nil { - t.Errorf("Getting value '%d', should not have raised an error. Got: %s", i, err) - } - if val != i { - t.Errorf("Value for key '%d' should be %d. Is %d", i, i, val) - } + assert.Nil(t, err) + assert.Equal(t, i, val) } - if len(c.items) != 5 { - t.Error("Items size should be 5. is: ", len(c.items)) - } + assert.Equal(t, 5, len(c.items)) } diff --git a/datastructures/cache/mocks/mocks.go b/datastructures/cache/mocks/mocks.go new file mode 100644 index 0000000..9fcf797 --- /dev/null +++ b/datastructures/cache/mocks/mocks.go @@ -0,0 +1,22 @@ +package mocks + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type LayerMock struct { + mock.Mock +} + +func (m *LayerMock) Get(ctx context.Context, key string) (string, error) { + args := m.Called(ctx, key) + return args.String(0), args.Error(1) +} + +func (m *LayerMock) Set(ctx context.Context, key string, value string) error { + args := m.Called(ctx, key, value) + return args.Error(0) +} + diff --git a/datastructures/cache/multilevel_test.go b/datastructures/cache/multilevel_test.go index 22d2907..6adc401 100644 --- a/datastructures/cache/multilevel_test.go +++ b/datastructures/cache/multilevel_test.go @@ -2,212 +2,61 @@ package cache import ( "context" - "errors" - "fmt" "testing" + "github.com/splitio/go-toolkit/v6/datastructures/cache/mocks" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) -type LayerMock struct { - getCall func(ctx context.Context, key string) (string, error) - setCall func(ctx context.Context, key string, value string) error -} - -func (m *LayerMock) Get(ctx context.Context, key string) (string, error) { - return m.getCall(ctx, key) -} - -func (m *LayerMock) Set(ctx context.Context, key string, value string) error { - return m.setCall(ctx, key, value) -} - -type callTracker struct { - calls map[string]int - t *testing.T -} - -func newCallTracker(t *testing.T) *callTracker { - return &callTracker{calls: make(map[string]int), t: t} -} - -func (c *callTracker) track(name string) { c.calls[name]++ } - -func (c *callTracker) reset() { c.calls = make(map[string]int) } - -func (c *callTracker) checkCall(name string, count int) { - c.t.Helper() - if c.calls[name] != count { - c.t.Errorf("calls for '%s' should be %d. is: %d", name, count, c.calls[name]) - } -} - -func (c *callTracker) checkTotalCalls(count int) { - c.t.Helper() - if len(c.calls) != count { - c.t.Errorf("The nomber of total calls should be '%d' and is '%d'", count, len(c.calls)) - } -} - func TestMultiLevelCache(t *testing.T) { // To test this we setup 3 layers of caching in order of querying: top -> mid -> bottom // Top layer has key1, doesn't have key2 (returns Miss), has key3 expired and errors out when requesting any other Key // Mid layer has key 2, returns a Miss on any other key, and fails the test if key1 is fetched (because it was present on top layer) // Bottom layer fails if key1 or 2 are requested, has key 3. returns Miss if any other key is requested - calls := newCallTracker(t) - topLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (string, error) { - calls.track(fmt.Sprintf("top:get:%s", key)) - switch key { - case "key1": - return "value1", nil - case "key2": - return "", &Miss{Where: "layer1", Key: "key2"} - case "key3": - return "", &Expired{Key: "key3", Value: "someOtherValue"} - default: - return "", errors.New("someError") - } - }, - setCall: func(ctx context.Context, key string, value string) error { - calls.track(fmt.Sprintf("top:set:%s", key)) - switch key { - case "key1": - t.Error("Set should not be called on the top layer for key1") - break - case "key2": - break - case "key3": - break - default: - return errors.New("someError") - } - return nil - }, - } - midLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (string, error) { - calls.track(fmt.Sprintf("mid:get:%s", key)) - switch key { - case "key1": - t.Error("Get should not be called on the mid layer for key1") - return "", nil - case "key2": - return "value2", nil - default: - return "", &Miss{Where: "layer2", Key: key} - } - }, - setCall: func(ctx context.Context, key string, value string) error { - calls.track(fmt.Sprintf("mid:set:%s", key)) - switch key { - case "key1": - t.Error("Set should not be called on the mid layer for key1") - case "key2": - t.Error("Set should not be called on the mid layer for key2") - case "key3": - default: - return errors.New("someError") - } - return nil - }, - } + ctx := context.Background() - bottomLayer := &LayerMock{ - getCall: func(ctx context.Context, key string) (string, error) { - calls.track(fmt.Sprintf("bot:get:%s", key)) - switch key { - case "key1": - t.Error("Get should not be called on the mid layer for key1") - return "", nil - case "key2": - t.Error("Get should not be called on the mid layer for key1") - return "", nil - case "key3": - return "value3", nil - default: - return "", &Miss{Where: "layer3", Key: key} - } - }, - setCall: func(ctx context.Context, key string, value string) error { - calls.track(fmt.Sprintf("bot:set:%s", key)) - switch key { - case "key1": - t.Error("Set should not be called on the mid layer for key1") - case "key2": - t.Error("Set should not be called on the mid layer for key2") - default: - return errors.New("someError") - } - return nil - }, - } + topLayer := &mocks.LayerMock{} + topLayer.On("Get", ctx, "key1").Once().Return("value1", nil) + topLayer.On("Get", ctx, "key2").Once().Return("", &Miss{Where: "layer1", Key: "key2"}) + topLayer.On("Get", ctx, "key3").Once().Return("value1", &Expired{Key: "key3", Value: "someOtherValue"}) + topLayer.On("Get", ctx, "key4").Once().Return("", &Miss{Where: "layer1", Key: "key4"}) + topLayer.On("Set", ctx, "key2", "value2").Once().Return(nil) + topLayer.On("Set", ctx, "key3", "value3").Once().Return(nil) + + midLayer := &mocks.LayerMock{} + midLayer.On("Get", ctx, "key2").Once().Return("value2", nil) + midLayer.On("Get", ctx, "key3").Once().Return("", &Miss{Where: "layer2", Key: "key3"}, nil) + midLayer.On("Get", ctx, "key4").Once().Return("", &Miss{Where: "layer2", Key: "key4"}) + midLayer.On("Set", ctx, "key3", "value3").Once().Return(nil) + + bottomLayer := &mocks.LayerMock{} + bottomLayer.On("Get", ctx, "key3").Once().Return("value3", nil) + bottomLayer.On("Get", ctx, "key4").Once().Return("", &Miss{Where: "layer3", Key: "key4"}) cacheML := MultiLevelCacheImpl[string, string]{ logger: logging.NewLogger(nil), layers: []MLCLayer[string, string]{topLayer, midLayer, bottomLayer}, } - value1, err := cacheML.Get(context.TODO(), "key1") - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } - if value1 != "value1" { - t.Error("Get 'key1' should return 'value1'. Got: ", value1) - } - calls.checkCall("top:get:key1", 1) - calls.checkTotalCalls(1) - - calls.reset() - value2, err := cacheML.Get(context.TODO(), "key2") - if err != nil { - t.Error("No error should have been returned. Got: ", err) - } - if value2 != "value2" { - t.Error("Get 'key2' should return 'value2'. Got: ", value2) - } - calls.checkCall("top:get:key2", 1) - calls.checkCall("mid:get:key2", 1) - calls.checkCall("top:set:key2", 1) - calls.checkTotalCalls(3) - - calls.reset() - value3, err := cacheML.Get(context.TODO(), "key3") - if err != nil { - t.Error("Error should be nil. Was: ", err) - } + value1, err := cacheML.Get(ctx, "key1") + assert.Nil(t, err) + assert.Equal(t, "value1", value1) - if value3 != "value3" { - t.Error("Get 'key3' should return 'value3'. Got: ", value3) - } - calls.checkCall("top:get:key3", 1) - calls.checkCall("mid:get:key3", 1) - calls.checkCall("bot:get:key3", 1) - calls.checkCall("mid:set:key3", 1) - calls.checkCall("top:set:key3", 1) - calls.checkTotalCalls(5) + value2, err := cacheML.Get(ctx, "key2") + assert.Nil(t, err) + assert.Equal(t, "value2", value2) - calls.reset() - value4, err := cacheML.Get(context.TODO(), "key4") - if err == nil { - t.Error("Error should be returned when getting nonexistant key.") - } + value3, err := cacheML.Get(ctx, "key3") + assert.Nil(t, err) + assert.Equal(t, "value3", value3) + value4, err := cacheML.Get(ctx, "key4") + assert.NotNil(t, err) asMiss, ok := err.(*Miss) - if !ok { - t.Errorf("Error should be of Miss type. Is %T", err) - } - - if asMiss.Where != "ALL_LEVELS" || asMiss.Key != "key4" { - t.Errorf("Incorrect 'Where' or 'Key'. Got: %+v", asMiss) - } - - if value4 != "" { - t.Errorf("Value returned for GET 'key4' should be nil. Is: %+v", value4) - } - calls.checkCall("top:get:key4", 1) - calls.checkCall("top:get:key4", 1) - calls.checkCall("top:get:key4", 1) - calls.checkTotalCalls(3) + assert.True(t, ok) + assert.Equal(t, "ALL_LEVELS", asMiss.Where) + assert.Equal(t, "key4", asMiss.Key) + assert.Equal(t, "", value4) } diff --git a/datastructures/queuecache/cache_test.go b/datastructures/queuecache/cache_test.go index 6e5d734..324b8e3 100644 --- a/datastructures/queuecache/cache_test.go +++ b/datastructures/queuecache/cache_test.go @@ -4,6 +4,8 @@ import ( "errors" "math" "testing" + + "github.com/stretchr/testify/assert" ) func TestCacheBasicUsage(t *testing.T) { @@ -36,78 +38,42 @@ func TestCacheBasicUsage(t *testing.T) { for index, item := range first5 { asInt, ok := item.(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != index { - t.Error("Each number should be equal to its index") - } + assert.True(t, ok) + assert.Equal(t, index, asInt) } offset := 5 next5, err := myCache.Fetch(5) - if err != nil { - t.Error(err) - } + assert.Nil(t, err) for index, item := range next5 { asInt, ok := item.(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != index+offset { - t.Error("Each number should be equal to its index") - } + assert.True(t, ok) + assert.Equal(t, index+offset, asInt) } index = 0 myCache = New(10, fetchMore) for i := 0; i < 100; i++ { elem, err := myCache.Fetch(1) - if err != nil { - t.Error(err) - } - + assert.Nil(t, err) asInt, ok := elem[0].(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != i { - t.Error("Each number should be equal to its index") - t.Error("asInt", asInt) - t.Error("index", i) - } + assert.True(t, ok) + assert.Equal(t, i, asInt) } elems, err := myCache.Fetch(1) - if elems != nil { - t.Error("Elem should be nil and is: ", elems) - } - - if err == nil || err.Error() != "NO_MORE_DATA" { - t.Error("Error should be NO_MORE_DATA and is: ", err.Error()) - } + assert.Nil(t, elems) + assert.ErrorContains(t, err, "NO_MORE_DATA") // Set index to 0 so that refill works and restart tests. index = 0 for i := 0; i < 100; i++ { elem, err := myCache.Fetch(1) - if err != nil { - t.Error(err) - } + assert.Nil(t, err) asInt, ok := elem[0].(int) - if !ok { - t.Error("Item should be stored as int and isn't") - } - - if asInt != i { - t.Error("Each number should be equal to its index") - t.Error("asInt", asInt) - t.Error("index", i) - } + assert.True(t, ok) + assert.Equal(t, i, asInt) } } @@ -118,18 +84,11 @@ func TestRefillPanic(t *testing.T) { myCache := New(10, fetchMore) result, err := myCache.Fetch(5) - - if result != nil { - t.Error("Result should have been nil and is: ", result) - } - if err == nil { - t.Error("Error should not have been nil") - } + assert.Nil(t, result) + assert.NotNil(t, err) _, ok := err.(*RefillError) - if !ok { - t.Error("Returned error should have been a RefillError") - } + assert.True(t, ok) } func TestCountWorksProperly(t *testing.T) { @@ -137,25 +96,17 @@ func TestCountWorksProperly(t *testing.T) { cache.readCursor = 0 cache.writeCursor = 0 - if cache.Count() != 0 { - t.Error("Count should be 0 and is: ", cache.Count()) - } + assert.Equal(t, 0, cache.Count()) cache.readCursor = 0 cache.writeCursor = 1 - if cache.Count() != 1 { - t.Error("Count should be 1 and is: ", cache.Count()) - } + assert.Equal(t, 1, cache.Count()) cache.readCursor = 50 cache.writeCursor = 99 - if cache.Count() != 49 { - t.Error("Count should be 49 and is: ", cache.Count()) - } + assert.Equal(t, 49, cache.Count()) cache.readCursor = 50 cache.writeCursor = 20 - if cache.Count() != 70 { - t.Error("Count should be 69 and is: ", cache.Count()) - } + assert.Equal(t, 70, cache.Count()) } diff --git a/datautils/compress.go b/datautils/compress.go index ca555aa..ee2f53a 100644 --- a/datautils/compress.go +++ b/datautils/compress.go @@ -5,7 +5,7 @@ import ( "compress/gzip" "compress/zlib" "fmt" - "io/ioutil" + "io" ) const ( @@ -48,7 +48,7 @@ func Decompress(data []byte, compressType int) ([]byte, error) { return nil, err } defer gz.Close() - raw, err := ioutil.ReadAll(gz) + raw, err := io.ReadAll(gz) if err != nil { return nil, err } @@ -59,7 +59,7 @@ func Decompress(data []byte, compressType int) ([]byte, error) { return nil, err } defer zl.Close() - raw, err := ioutil.ReadAll(zl) + raw, err := io.ReadAll(zl) if err != nil { return nil, err } diff --git a/datautils/compress_test.go b/datautils/compress_test.go index 0acfbcb..6d8973b 100644 --- a/datautils/compress_test.go +++ b/datautils/compress_test.go @@ -1,53 +1,41 @@ package datautils -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestCompressDecompressError(t *testing.T) { data := "compression" _, err := Compress([]byte(data), 4) - if err == nil || err.Error() != "compression type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "compression type not found") _, err = Decompress([]byte("err"), 4) - if err == nil || err.Error() != "compression type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "compression type not found") } func TestCompressDecompressGZip(t *testing.T) { data := "compression gzip" compressed, err := Compress([]byte(data), GZip) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) decompressed, err := Decompress(compressed, GZip) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) - if string(decompressed) != data { - t.Error("It should be equal") - } + assert.Equal(t, data, string(decompressed)) } func TestCompressDecompressZLib(t *testing.T) { data := "compression zlib" compressed, err := Compress([]byte(data), Zlib) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) decompressed, err := Decompress(compressed, Zlib) - if err != nil { - t.Error("err should be nil") - } + assert.Nil(t, err) - if string(decompressed) != data { - t.Error("It should be equal") - } + assert.Equal(t, data, string(decompressed)) } diff --git a/datautils/encode_test.go b/datautils/encode_test.go index acd9bf0..1e79b30 100644 --- a/datautils/encode_test.go +++ b/datautils/encode_test.go @@ -1,32 +1,25 @@ package datautils -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestError(t *testing.T) { _, err := Encode([]byte("err"), 4) - if err == nil || err.Error() != "encode type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "encode type not found") _, err = Decode("err", 4) - if err == nil || err.Error() != "encode type not found" { - t.Error("It should return err") - } + assert.ErrorContains(t, err, "encode type not found") } func TestB64EncodeDecode(t *testing.T) { data := "encode b64" encoded, err := Encode([]byte(data), Base64) - if err != nil { - t.Error("It should not return err") - } + assert.Nil(t, err) decoded, err := Decode(encoded, Base64) - if err != nil { - t.Error("It should not return err") - } - - if data != string(decoded) { - t.Error("It should be equal") - } + assert.Nil(t, err) + assert.Equal(t, data, string(decoded)) } diff --git a/deepcopy/deepcopy.go b/deepcopy/deepcopy.go deleted file mode 100644 index 73b83f4..0000000 --- a/deepcopy/deepcopy.go +++ /dev/null @@ -1,114 +0,0 @@ -package deepcopy - -import ( - "reflect" - "time" -) - -// Interface for delegating copy process to type -type Interface interface { - DeepCopy() interface{} -} - -// Copy creates a deep copy of whatever is passed to it and returns the copy -// in an interface{}. The returned value will need to be asserted to the -// correct type. -func Copy(src interface{}) interface{} { - if src == nil { - return nil - } - - // Make the interface a reflect.Value - original := reflect.ValueOf(src) - - // Make a copy of the same type as the original. - cpy := reflect.New(original.Type()).Elem() - - // Recursively copy the original. - copyRecursive(original, cpy) - - // Return the copy as an interface. - return cpy.Interface() -} - -// copyRecursive does the actual copying of the interface. It currently has -// limited support for what it can handle. Add as needed. -func copyRecursive(original, cpy reflect.Value) { - // check for implement deepcopy.Interface - if original.CanInterface() { - if copier, ok := original.Interface().(Interface); ok { - cpy.Set(reflect.ValueOf(copier.DeepCopy())) - return - } - } - - // handle according to original's Kind - switch original.Kind() { - case reflect.Ptr: - // Get the actual value being pointed to. - originalValue := original.Elem() - - // if it isn't valid, return. - if !originalValue.IsValid() { - return - } - cpy.Set(reflect.New(originalValue.Type())) - copyRecursive(originalValue, cpy.Elem()) - - case reflect.Interface: - // If this is a nil, don't do anything - if original.IsNil() { - return - } - // Get the value for the interface, not the pointer. - originalValue := original.Elem() - - // Get the value by calling Elem(). - copyValue := reflect.New(originalValue.Type()).Elem() - copyRecursive(originalValue, copyValue) - cpy.Set(copyValue) - - case reflect.Struct: - t, ok := original.Interface().(time.Time) - if ok { - cpy.Set(reflect.ValueOf(t)) - return - } - // Go through each field of the struct and copy it. - for i := 0; i < original.NumField(); i++ { - // The Type's StructField for a given field is checked to see if StructField.PkgPath - // is set to determine if the field is exported or not because CanSet() returns false - // for settable fields. I'm not sure why. -mohae - if original.Type().Field(i).PkgPath != "" { - continue - } - copyRecursive(original.Field(i), cpy.Field(i)) - } - - case reflect.Slice: - if original.IsNil() { - return - } - // Make a new slice and copy each element. - cpy.Set(reflect.MakeSlice(original.Type(), original.Len(), original.Cap())) - for i := 0; i < original.Len(); i++ { - copyRecursive(original.Index(i), cpy.Index(i)) - } - - case reflect.Map: - if original.IsNil() { - return - } - cpy.Set(reflect.MakeMap(original.Type())) - for _, key := range original.MapKeys() { - originalValue := original.MapIndex(key) - copyValue := reflect.New(originalValue.Type()).Elem() - copyRecursive(originalValue, copyValue) - copyKey := Copy(key.Interface()) - cpy.SetMapIndex(reflect.ValueOf(copyKey), copyValue) - } - - default: - cpy.Set(original) - } -} diff --git a/deepcopy/deepcopy_test.go b/deepcopy/deepcopy_test.go deleted file mode 100644 index f150b1a..0000000 --- a/deepcopy/deepcopy_test.go +++ /dev/null @@ -1,1110 +0,0 @@ -package deepcopy - -import ( - "fmt" - "reflect" - "testing" - "time" - "unsafe" -) - -// just basic is this working stuff -func TestSimple(t *testing.T) { - Strings := []string{"a", "b", "c"} - cpyS := Copy(Strings).([]string) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyS)).Data { - t.Error("[]string: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyBools - } - if len(cpyS) != len(Strings) { - t.Errorf("[]string: len was %d; want %d", len(cpyS), len(Strings)) - goto CopyBools - } - for i, v := range Strings { - if v != cpyS[i] { - t.Errorf("[]string: got %v at index %d of the copy; want %v", cpyS[i], i, v) - } - } - -CopyBools: - Bools := []bool{true, true, false, false} - cpyB := Copy(Bools).([]bool) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyB)).Data { - t.Error("[]bool: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyBytes - } - if len(cpyB) != len(Bools) { - t.Errorf("[]bool: len was %d; want %d", len(cpyB), len(Bools)) - goto CopyBytes - } - for i, v := range Bools { - if v != cpyB[i] { - t.Errorf("[]bool: got %v at index %d of the copy; want %v", cpyB[i], i, v) - } - } - -CopyBytes: - Bytes := []byte("hello") - cpyBt := Copy(Bytes).([]byte) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyBt)).Data { - t.Error("[]byte: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyInts - } - if len(cpyBt) != len(Bytes) { - t.Errorf("[]byte: len was %d; want %d", len(cpyBt), len(Bytes)) - goto CopyInts - } - for i, v := range Bytes { - if v != cpyBt[i] { - t.Errorf("[]byte: got %v at index %d of the copy; want %v", cpyBt[i], i, v) - } - } - -CopyInts: - Ints := []int{42} - cpyI := Copy(Ints).([]int) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyI)).Data { - t.Error("[]int: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyUints - } - if len(cpyI) != len(Ints) { - t.Errorf("[]int: len was %d; want %d", len(cpyI), len(Ints)) - goto CopyUints - } - for i, v := range Ints { - if v != cpyI[i] { - t.Errorf("[]int: got %v at index %d of the copy; want %v", cpyI[i], i, v) - } - } - -CopyUints: - Uints := []uint{1, 2, 3, 4, 5} - cpyU := Copy(Uints).([]uint) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyU)).Data { - t.Error("[]: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyFloat32s - } - if len(cpyU) != len(Uints) { - t.Errorf("[]uint: len was %d; want %d", len(cpyU), len(Uints)) - goto CopyFloat32s - } - for i, v := range Uints { - if v != cpyU[i] { - t.Errorf("[]uint: got %v at index %d of the copy; want %v", cpyU[i], i, v) - } - } - -CopyFloat32s: - Float32s := []float32{3.14} - cpyF := Copy(Float32s).([]float32) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyF)).Data { - t.Error("[]float32: expected SliceHeader data pointers to point to different locations, they didn't") - goto CopyInterfaces - } - if len(cpyF) != len(Float32s) { - t.Errorf("[]float32: len was %d; want %d", len(cpyF), len(Float32s)) - goto CopyInterfaces - } - for i, v := range Float32s { - if v != cpyF[i] { - t.Errorf("[]float32: got %v at index %d of the copy; want %v", cpyF[i], i, v) - } - } - -CopyInterfaces: - Interfaces := []interface{}{"a", 42, true, 4.32} - cpyIf := Copy(Interfaces).([]interface{}) - if (*reflect.SliceHeader)(unsafe.Pointer(&Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyIf)).Data { - t.Error("[]interfaces: expected SliceHeader data pointers to point to different locations, they didn't") - return - } - if len(cpyIf) != len(Interfaces) { - t.Errorf("[]interface{}: len was %d; want %d", len(cpyIf), len(Interfaces)) - return - } - for i, v := range Interfaces { - if v != cpyIf[i] { - t.Errorf("[]interface{}: got %v at index %d of the copy; want %v", cpyIf[i], i, v) - } - } -} - -type Basics struct { - String string - Strings []string - StringArr [4]string - Bool bool - Bools []bool - Byte byte - Bytes []byte - Int int - Ints []int - Int8 int8 - Int8s []int8 - Int16 int16 - Int16s []int16 - Int32 int32 - Int32s []int32 - Int64 int64 - Int64s []int64 - Uint uint - Uints []uint - Uint8 uint8 - Uint8s []uint8 - Uint16 uint16 - Uint16s []uint16 - Uint32 uint32 - Uint32s []uint32 - Uint64 uint64 - Uint64s []uint64 - Float32 float32 - Float32s []float32 - Float64 float64 - Float64s []float64 - Complex64 complex64 - Complex64s []complex64 - Complex128 complex128 - Complex128s []complex128 - Interface interface{} - Interfaces []interface{} -} - -// These tests test that all supported basic types are copied correctly. This -// is done by copying a struct with fields of most of the basic types as []T. -func TestMostTypes(t *testing.T) { - test := Basics{ - String: "kimchi", - Strings: []string{"uni", "ika"}, - StringArr: [4]string{"malort", "barenjager", "fernet", "salmiakki"}, - Bool: true, - Bools: []bool{true, false, true}, - Byte: 'z', - Bytes: []byte("abc"), - Int: 42, - Ints: []int{0, 1, 3, 4}, - Int8: 8, - Int8s: []int8{8, 9, 10}, - Int16: 16, - Int16s: []int16{16, 17, 18, 19}, - Int32: 32, - Int32s: []int32{32, 33}, - Int64: 64, - Int64s: []int64{64}, - Uint: 420, - Uints: []uint{11, 12, 13}, - Uint8: 81, - Uint8s: []uint8{81, 82}, - Uint16: 160, - Uint16s: []uint16{160, 161, 162, 163, 164}, - Uint32: 320, - Uint32s: []uint32{320, 321}, - Uint64: 640, - Uint64s: []uint64{6400, 6401, 6402, 6403}, - Float32: 32.32, - Float32s: []float32{32.32, 33}, - Float64: 64.1, - Float64s: []float64{64, 65, 66}, - Complex64: complex64(-64 + 12i), - Complex64s: []complex64{complex64(-65 + 11i), complex64(66 + 10i)}, - Complex128: complex128(-128 + 12i), - Complex128s: []complex128{complex128(-128 + 11i), complex128(129 + 10i)}, - Interfaces: []interface{}{42, true, "pan-galactic"}, - } - - cpy := Copy(test).(Basics) - - // see if they point to the same location - if fmt.Sprintf("%p", &cpy) == fmt.Sprintf("%p", &test) { - t.Error("address of copy was the same as original; they should be different") - return - } - - // Go through each field and check to see it got copied properly - if cpy.String != test.String { - t.Errorf("String: got %v; want %v", cpy.String, test.String) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Strings)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Strings)).Data { - t.Error("Strings: address of copy was the same as original; they should be different") - goto StringArr - } - - if len(cpy.Strings) != len(test.Strings) { - t.Errorf("Strings: len was %d; want %d", len(cpy.Strings), len(test.Strings)) - goto StringArr - } - for i, v := range test.Strings { - if v != cpy.Strings[i] { - t.Errorf("Strings: got %v at index %d of the copy; want %v", cpy.Strings[i], i, v) - } - } - -StringArr: - if unsafe.Pointer(&test.StringArr) == unsafe.Pointer(&cpy.StringArr) { - t.Error("StringArr: address of copy was the same as original; they should be different") - goto Bools - } - for i, v := range test.StringArr { - if v != cpy.StringArr[i] { - t.Errorf("StringArr: got %v at index %d of the copy; want %v", cpy.StringArr[i], i, v) - } - } - -Bools: - if cpy.Bool != test.Bool { - t.Errorf("Bool: got %v; want %v", cpy.Bool, test.Bool) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Bools)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Bools)).Data { - t.Error("Bools: address of copy was the same as original; they should be different") - goto Bytes - } - if len(cpy.Bools) != len(test.Bools) { - t.Errorf("Bools: len was %d; want %d", len(cpy.Bools), len(test.Bools)) - goto Bytes - } - for i, v := range test.Bools { - if v != cpy.Bools[i] { - t.Errorf("Bools: got %v at index %d of the copy; want %v", cpy.Bools[i], i, v) - } - } - -Bytes: - if cpy.Byte != test.Byte { - t.Errorf("Byte: got %v; want %v", cpy.Byte, test.Byte) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Bytes)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Bytes)).Data { - t.Error("Bytes: address of copy was the same as original; they should be different") - goto Ints - } - if len(cpy.Bytes) != len(test.Bytes) { - t.Errorf("Bytes: len was %d; want %d", len(cpy.Bytes), len(test.Bytes)) - goto Ints - } - for i, v := range test.Bytes { - if v != cpy.Bytes[i] { - t.Errorf("Bytes: got %v at index %d of the copy; want %v", cpy.Bytes[i], i, v) - } - } - -Ints: - if cpy.Int != test.Int { - t.Errorf("Int: got %v; want %v", cpy.Int, test.Int) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Ints)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Ints)).Data { - t.Error("Ints: address of copy was the same as original; they should be different") - goto Int8s - } - if len(cpy.Ints) != len(test.Ints) { - t.Errorf("Ints: len was %d; want %d", len(cpy.Ints), len(test.Ints)) - goto Int8s - } - for i, v := range test.Ints { - if v != cpy.Ints[i] { - t.Errorf("Ints: got %v at index %d of the copy; want %v", cpy.Ints[i], i, v) - } - } - -Int8s: - if cpy.Int8 != test.Int8 { - t.Errorf("Int8: got %v; want %v", cpy.Int8, test.Int8) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int8s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int8s)).Data { - t.Error("Int8s: address of copy was the same as original; they should be different") - goto Int16s - } - if len(cpy.Int8s) != len(test.Int8s) { - t.Errorf("Int8s: len was %d; want %d", len(cpy.Int8s), len(test.Int8s)) - goto Int16s - } - for i, v := range test.Int8s { - if v != cpy.Int8s[i] { - t.Errorf("Int8s: got %v at index %d of the copy; want %v", cpy.Int8s[i], i, v) - } - } - -Int16s: - if cpy.Int16 != test.Int16 { - t.Errorf("Int16: got %v; want %v", cpy.Int16, test.Int16) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int16s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int16s)).Data { - t.Error("Int16s: address of copy was the same as original; they should be different") - goto Int32s - } - if len(cpy.Int16s) != len(test.Int16s) { - t.Errorf("Int16s: len was %d; want %d", len(cpy.Int16s), len(test.Int16s)) - goto Int32s - } - for i, v := range test.Int16s { - if v != cpy.Int16s[i] { - t.Errorf("Int16s: got %v at index %d of the copy; want %v", cpy.Int16s[i], i, v) - } - } - -Int32s: - if cpy.Int32 != test.Int32 { - t.Errorf("Int32: got %v; want %v", cpy.Int32, test.Int32) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int32s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int32s)).Data { - t.Error("Int32s: address of copy was the same as original; they should be different") - goto Int64s - } - if len(cpy.Int32s) != len(test.Int32s) { - t.Errorf("Int32s: len was %d; want %d", len(cpy.Int32s), len(test.Int32s)) - goto Int64s - } - for i, v := range test.Int32s { - if v != cpy.Int32s[i] { - t.Errorf("Int32s: got %v at index %d of the copy; want %v", cpy.Int32s[i], i, v) - } - } - -Int64s: - if cpy.Int64 != test.Int64 { - t.Errorf("Int64: got %v; want %v", cpy.Int64, test.Int64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Int64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Int64s)).Data { - t.Error("Int64s: address of copy was the same as original; they should be different") - goto Uints - } - if len(cpy.Int64s) != len(test.Int64s) { - t.Errorf("Int64s: len was %d; want %d", len(cpy.Int64s), len(test.Int64s)) - goto Uints - } - for i, v := range test.Int64s { - if v != cpy.Int64s[i] { - t.Errorf("Int64s: got %v at index %d of the copy; want %v", cpy.Int64s[i], i, v) - } - } - -Uints: - if cpy.Uint != test.Uint { - t.Errorf("Uint: got %v; want %v", cpy.Uint, test.Uint) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uints)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uints)).Data { - t.Error("Uints: address of copy was the same as original; they should be different") - goto Uint8s - } - if len(cpy.Uints) != len(test.Uints) { - t.Errorf("Uints: len was %d; want %d", len(cpy.Uints), len(test.Uints)) - goto Uint8s - } - for i, v := range test.Uints { - if v != cpy.Uints[i] { - t.Errorf("Uints: got %v at index %d of the copy; want %v", cpy.Uints[i], i, v) - } - } - -Uint8s: - if cpy.Uint8 != test.Uint8 { - t.Errorf("Uint8: got %v; want %v", cpy.Uint8, test.Uint8) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint8s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint8s)).Data { - t.Error("Uint8s: address of copy was the same as original; they should be different") - goto Uint16s - } - if len(cpy.Uint8s) != len(test.Uint8s) { - t.Errorf("Uint8s: len was %d; want %d", len(cpy.Uint8s), len(test.Uint8s)) - goto Uint16s - } - for i, v := range test.Uint8s { - if v != cpy.Uint8s[i] { - t.Errorf("Uint8s: got %v at index %d of the copy; want %v", cpy.Uint8s[i], i, v) - } - } - -Uint16s: - if cpy.Uint16 != test.Uint16 { - t.Errorf("Uint16: got %v; want %v", cpy.Uint16, test.Uint16) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint16s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint16s)).Data { - t.Error("Uint16s: address of copy was the same as original; they should be different") - goto Uint32s - } - if len(cpy.Uint16s) != len(test.Uint16s) { - t.Errorf("Uint16s: len was %d; want %d", len(cpy.Uint16s), len(test.Uint16s)) - goto Uint32s - } - for i, v := range test.Uint16s { - if v != cpy.Uint16s[i] { - t.Errorf("Uint16s: got %v at index %d of the copy; want %v", cpy.Uint16s[i], i, v) - } - } - -Uint32s: - if cpy.Uint32 != test.Uint32 { - t.Errorf("Uint32: got %v; want %v", cpy.Uint32, test.Uint32) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint32s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint32s)).Data { - t.Error("Uint32s: address of copy was the same as original; they should be different") - goto Uint64s - } - if len(cpy.Uint32s) != len(test.Uint32s) { - t.Errorf("Uint32s: len was %d; want %d", len(cpy.Uint32s), len(test.Uint32s)) - goto Uint64s - } - for i, v := range test.Uint32s { - if v != cpy.Uint32s[i] { - t.Errorf("Uint32s: got %v at index %d of the copy; want %v", cpy.Uint32s[i], i, v) - } - } - -Uint64s: - if cpy.Uint64 != test.Uint64 { - t.Errorf("Uint64: got %v; want %v", cpy.Uint64, test.Uint64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Uint64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Uint64s)).Data { - t.Error("Uint64s: address of copy was the same as original; they should be different") - goto Float32s - } - if len(cpy.Uint64s) != len(test.Uint64s) { - t.Errorf("Uint64s: len was %d; want %d", len(cpy.Uint64s), len(test.Uint64s)) - goto Float32s - } - for i, v := range test.Uint64s { - if v != cpy.Uint64s[i] { - t.Errorf("Uint64s: got %v at index %d of the copy; want %v", cpy.Uint64s[i], i, v) - } - } - -Float32s: - if cpy.Float32 != test.Float32 { - t.Errorf("Float32: got %v; want %v", cpy.Float32, test.Float32) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Float32s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Float32s)).Data { - t.Error("Float32s: address of copy was the same as original; they should be different") - goto Float64s - } - if len(cpy.Float32s) != len(test.Float32s) { - t.Errorf("Float32s: len was %d; want %d", len(cpy.Float32s), len(test.Float32s)) - goto Float64s - } - for i, v := range test.Float32s { - if v != cpy.Float32s[i] { - t.Errorf("Float32s: got %v at index %d of the copy; want %v", cpy.Float32s[i], i, v) - } - } - -Float64s: - if cpy.Float64 != test.Float64 { - t.Errorf("Float64: got %v; want %v", cpy.Float64, test.Float64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Float64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Float64s)).Data { - t.Error("Float64s: address of copy was the same as original; they should be different") - goto Complex64s - } - if len(cpy.Float64s) != len(test.Float64s) { - t.Errorf("Float64s: len was %d; want %d", len(cpy.Float64s), len(test.Float64s)) - goto Complex64s - } - for i, v := range test.Float64s { - if v != cpy.Float64s[i] { - t.Errorf("Float64s: got %v at index %d of the copy; want %v", cpy.Float64s[i], i, v) - } - } - -Complex64s: - if cpy.Complex64 != test.Complex64 { - t.Errorf("Complex64: got %v; want %v", cpy.Complex64, test.Complex64) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Complex64s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Complex64s)).Data { - t.Error("Complex64s: address of copy was the same as original; they should be different") - goto Complex128s - } - if len(cpy.Complex64s) != len(test.Complex64s) { - t.Errorf("Complex64s: len was %d; want %d", len(cpy.Complex64s), len(test.Complex64s)) - goto Complex128s - } - for i, v := range test.Complex64s { - if v != cpy.Complex64s[i] { - t.Errorf("Complex64s: got %v at index %d of the copy; want %v", cpy.Complex64s[i], i, v) - } - } - -Complex128s: - if cpy.Complex128 != test.Complex128 { - t.Errorf("Complex128s: got %v; want %v", cpy.Complex128s, test.Complex128s) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Complex128s)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Complex128s)).Data { - t.Error("Complex128s: address of copy was the same as original; they should be different") - goto Interfaces - } - if len(cpy.Complex128s) != len(test.Complex128s) { - t.Errorf("Complex128s: len was %d; want %d", len(cpy.Complex128s), len(test.Complex128s)) - goto Interfaces - } - for i, v := range test.Complex128s { - if v != cpy.Complex128s[i] { - t.Errorf("Complex128s: got %v at index %d of the copy; want %v", cpy.Complex128s[i], i, v) - } - } - -Interfaces: - if cpy.Interface != test.Interface { - t.Errorf("Interface: got %v; want %v", cpy.Interface, test.Interface) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&test.Interfaces)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.Interfaces)).Data { - t.Error("Interfaces: address of copy was the same as original; they should be different") - return - } - if len(cpy.Interfaces) != len(test.Interfaces) { - t.Errorf("Interfaces: len was %d; want %d", len(cpy.Interfaces), len(test.Interfaces)) - return - } - for i, v := range test.Interfaces { - if v != cpy.Interfaces[i] { - t.Errorf("Interfaces: got %v at index %d of the copy; want %v", cpy.Interfaces[i], i, v) - } - } -} - -// not meant to be exhaustive -func TestComplexSlices(t *testing.T) { - orig3Int := [][][]int{[][]int{[]int{1, 2, 3}, []int{11, 22, 33}}, [][]int{[]int{7, 8, 9}, []int{66, 77, 88, 99}}} - cpyI := Copy(orig3Int).([][][]int) - if (*reflect.SliceHeader)(unsafe.Pointer(&orig3Int)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyI)).Data { - t.Error("[][][]int: address of copy was the same as original; they should be different") - return - } - if len(orig3Int) != len(cpyI) { - t.Errorf("[][][]int: len of copy was %d; want %d", len(cpyI), len(orig3Int)) - goto sliceMap - } - for i, v := range orig3Int { - if len(v) != len(cpyI[i]) { - t.Errorf("[][][]int: len of element %d was %d; want %d", i, len(cpyI[i]), len(v)) - continue - } - for j, vv := range v { - if len(vv) != len(cpyI[i][j]) { - t.Errorf("[][][]int: len of element %d:%d was %d, want %d", i, j, len(cpyI[i][j]), len(vv)) - continue - } - for k, vvv := range vv { - if vvv != cpyI[i][j][k] { - t.Errorf("[][][]int: element %d:%d:%d was %d, want %d", i, j, k, cpyI[i][j][k], vvv) - } - } - } - - } - -sliceMap: - slMap := []map[int]string{map[int]string{0: "a", 1: "b"}, map[int]string{10: "k", 11: "l", 12: "m"}} - cpyM := Copy(slMap).([]map[int]string) - if (*reflect.SliceHeader)(unsafe.Pointer(&slMap)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpyM)).Data { - t.Error("[]map[int]string: address of copy was the same as original; they should be different") - } - if len(slMap) != len(cpyM) { - t.Errorf("[]map[int]string: len of copy was %d; want %d", len(cpyM), len(slMap)) - goto done - } - for i, v := range slMap { - if len(v) != len(cpyM[i]) { - t.Errorf("[]map[int]string: len of element %d was %d; want %d", i, len(cpyM[i]), len(v)) - continue - } - for k, vv := range v { - val, ok := cpyM[i][k] - if !ok { - t.Errorf("[]map[int]string: element %d was expected to have a value at key %d, it didn't", i, k) - continue - } - if val != vv { - t.Errorf("[]map[int]string: element %d, key %d: got %s, want %s", i, k, val, vv) - } - } - } -done: -} - -type A struct { - Int int - String string - UintSl []uint - NilSl []string - Map map[string]int - MapB map[string]*B - SliceB []B - B - T time.Time -} - -type B struct { - Vals []string -} - -var AStruct = A{ - Int: 42, - String: "Konichiwa", - UintSl: []uint{0, 1, 2, 3}, - Map: map[string]int{"a": 1, "b": 2}, - MapB: map[string]*B{ - "hi": &B{Vals: []string{"hello", "bonjour"}}, - "bye": &B{Vals: []string{"good-bye", "au revoir"}}, - }, - SliceB: []B{ - B{Vals: []string{"Ciao", "Aloha"}}, - }, - B: B{Vals: []string{"42"}}, - T: time.Now(), -} - -func TestStructA(t *testing.T) { - cpy := Copy(AStruct).(A) - if &cpy == &AStruct { - t.Error("expected copy to have a different address than the original; it was the same") - return - } - if cpy.Int != AStruct.Int { - t.Errorf("A.Int: got %v, want %v", cpy.Int, AStruct.Int) - } - if cpy.String != AStruct.String { - t.Errorf("A.String: got %v; want %v", cpy.String, AStruct.String) - } - if (*reflect.SliceHeader)(unsafe.Pointer(&cpy.UintSl)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.UintSl)).Data { - t.Error("A.Uintsl: expected the copies address to be different; it wasn't") - goto NilSl - } - if len(cpy.UintSl) != len(AStruct.UintSl) { - t.Errorf("A.UintSl: got len of %d, want %d", len(cpy.UintSl), len(AStruct.UintSl)) - goto NilSl - } - for i, v := range AStruct.UintSl { - if cpy.UintSl[i] != v { - t.Errorf("A.UintSl %d: got %d, want %d", i, cpy.UintSl[i], v) - } - } - -NilSl: - if cpy.NilSl != nil { - t.Error("A.NilSl: expected slice to be nil, it wasn't") - } - - if *(*uintptr)(unsafe.Pointer(&cpy.Map)) == *(*uintptr)(unsafe.Pointer(&AStruct.Map)) { - t.Error("A.Map: expected the copy's address to be different; it wasn't") - goto AMapB - } - if len(cpy.Map) != len(AStruct.Map) { - t.Errorf("A.Map: got len of %d, want %d", len(cpy.Map), len(AStruct.Map)) - goto AMapB - } - for k, v := range AStruct.Map { - val, ok := cpy.Map[k] - if !ok { - t.Errorf("A.Map: expected the key %s to exist in the copy, it didn't", k) - continue - } - if val != v { - t.Errorf("A.Map[%s]: got %d, want %d", k, val, v) - } - } - -AMapB: - if *(*uintptr)(unsafe.Pointer(&cpy.MapB)) == *(*uintptr)(unsafe.Pointer(&AStruct.MapB)) { - t.Error("A.MapB: expected the copy's address to be different; it wasn't") - goto ASliceB - } - if len(cpy.MapB) != len(AStruct.MapB) { - t.Errorf("A.MapB: got len of %d, want %d", len(cpy.MapB), len(AStruct.MapB)) - goto ASliceB - } - for k, v := range AStruct.MapB { - val, ok := cpy.MapB[k] - if !ok { - t.Errorf("A.MapB: expected the key %s to exist in the copy, it didn't", k) - continue - } - if unsafe.Pointer(val) == unsafe.Pointer(v) { - t.Errorf("A.MapB[%s]: expected the addresses of the values to be different; they weren't", k) - continue - } - // the slice headers should point to different data - if (*reflect.SliceHeader)(unsafe.Pointer(&v.Vals)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&val.Vals)).Data { - t.Errorf("%s: expected B's SliceHeaders to point to different Data locations; they did not.", k) - continue - } - for i, vv := range v.Vals { - if vv != val.Vals[i] { - t.Errorf("A.MapB[%s].Vals[%d]: got %s want %s", k, i, vv, val.Vals[i]) - } - } - } - -ASliceB: - if (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.SliceB)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.SliceB)).Data { - t.Error("A.SliceB: expected the copy's address to be different; it wasn't") - goto B - } - - if len(AStruct.SliceB) != len(cpy.SliceB) { - t.Errorf("A.SliceB: got length of %d; want %d", len(cpy.SliceB), len(AStruct.SliceB)) - goto B - } - - for i := range AStruct.SliceB { - if unsafe.Pointer(&AStruct.SliceB[i]) == unsafe.Pointer(&cpy.SliceB[i]) { - t.Errorf("A.SliceB[%d]: expected them to have different addresses, they didn't", i) - continue - } - if (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.SliceB[i].Vals)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.SliceB[i].Vals)).Data { - t.Errorf("A.SliceB[%d]: expected B.Vals SliceHeader.Data to point to different locations; they did not", i) - continue - } - if len(AStruct.SliceB[i].Vals) != len(cpy.SliceB[i].Vals) { - t.Errorf("A.SliceB[%d]: expected B's vals to have the same length, they didn't", i) - continue - } - for j, val := range AStruct.SliceB[i].Vals { - if val != cpy.SliceB[i].Vals[j] { - t.Errorf("A.SliceB[%d].Vals[%d]: got %v; want %v", i, j, cpy.SliceB[i].Vals[j], val) - } - } - } -B: - if unsafe.Pointer(&AStruct.B) == unsafe.Pointer(&cpy.B) { - t.Error("A.B: expected them to have different addresses, they didn't") - goto T - } - if (*reflect.SliceHeader)(unsafe.Pointer(&AStruct.B.Vals)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&cpy.B.Vals)).Data { - t.Error("A.B.Vals: expected the SliceHeaders.Data to point to different locations; they didn't") - goto T - } - if len(AStruct.B.Vals) != len(cpy.B.Vals) { - t.Error("A.B.Vals: expected their lengths to be the same, they weren't") - goto T - } - for i, v := range AStruct.B.Vals { - if v != cpy.B.Vals[i] { - t.Errorf("A.B.Vals[%d]: got %s want %s", i, cpy.B.Vals[i], v) - } - } -T: - if fmt.Sprintf("%p", &AStruct.T) == fmt.Sprintf("%p", &cpy.T) { - t.Error("A.T: expected them to have different addresses, they didn't") - return - } - if AStruct.T != cpy.T { - t.Errorf("A.T: got %v, want %v", cpy.T, AStruct.T) - } -} - -type Unexported struct { - A string - B int - aa string - bb int - cc []int - dd map[string]string -} - -func TestUnexportedFields(t *testing.T) { - u := &Unexported{ - A: "A", - B: 42, - aa: "aa", - bb: 42, - cc: []int{1, 2, 3}, - dd: map[string]string{"hello": "bonjour"}, - } - cpy := Copy(u).(*Unexported) - if cpy == u { - t.Error("expected addresses to be different, they weren't") - return - } - if u.A != cpy.A { - t.Errorf("Unexported.A: got %s want %s", cpy.A, u.A) - } - if u.B != cpy.B { - t.Errorf("Unexported.A: got %d want %d", cpy.B, u.B) - } - if cpy.aa != "" { - t.Errorf("Unexported.aa: unexported field should not be set, it was set to %s", cpy.aa) - } - if cpy.bb != 0 { - t.Errorf("Unexported.bb: unexported field should not be set, it was set to %d", cpy.bb) - } - if cpy.cc != nil { - t.Errorf("Unexported.cc: unexported field should not be set, it was set to %#v", cpy.cc) - } - if cpy.dd != nil { - t.Errorf("Unexported.dd: unexported field should not be set, it was set to %#v", cpy.dd) - } -} - -// Note: this test will fail until https://github.com/golang/go/issues/15716 is -// fixed and the version it is part of gets released. -type T struct { - time.Time -} - -func TestTimeCopy(t *testing.T) { - tests := []struct { - Y int - M time.Month - D int - h int - m int - s int - nsec int - TZ string - }{ - {2016, time.July, 4, 23, 11, 33, 3000, "America/New_York"}, - {2015, time.October, 31, 9, 44, 23, 45935, "UTC"}, - {2014, time.May, 5, 22, 01, 50, 219300, "Europe/Prague"}, - } - - for i, test := range tests { - l, err := time.LoadLocation(test.TZ) - if err != nil { - t.Errorf("%d: unexpected error: %s", i, err) - continue - } - var x T - x.Time = time.Date(test.Y, test.M, test.D, test.h, test.m, test.s, test.nsec, l) - c := Copy(x).(T) - if fmt.Sprintf("%p", &c) == fmt.Sprintf("%p", &x) { - t.Errorf("%d: expected the copy to have a different address than the original value; they were the same: %p %p", i, &c, &x) - continue - } - if x.UnixNano() != c.UnixNano() { - t.Errorf("%d: nanotime: got %v; want %v", i, c.UnixNano(), x.UnixNano()) - continue - } - if x.Location() != c.Location() { - t.Errorf("%d: location: got %q; want %q", i, c.Location(), x.Location()) - } - } -} - -func TestPointerToStruct(t *testing.T) { - type Foo struct { - Bar int - } - - f := &Foo{Bar: 42} - cpy := Copy(f) - if f == cpy { - t.Errorf("expected copy to point to a different location: orig: %p; copy: %p", f, cpy) - } - if !reflect.DeepEqual(f, cpy) { - t.Errorf("expected the copy to be equal to the original (except for memory location); it wasn't: got %#v; want %#v", f, cpy) - } -} - -func TestIssue9(t *testing.T) { - // simple pointer copy - x := 42 - testA := map[string]*int{ - "a": nil, - "b": &x, - } - copyA := Copy(testA).(map[string]*int) - if unsafe.Pointer(&testA) == unsafe.Pointer(©A) { - t.Fatalf("expected the map pointers to be different: testA: %v\tcopyA: %v", unsafe.Pointer(&testA), unsafe.Pointer(©A)) - } - if !reflect.DeepEqual(testA, copyA) { - t.Errorf("got %#v; want %#v", copyA, testA) - } - if testA["b"] == copyA["b"] { - t.Errorf("entries for 'b' pointed to the same address: %v; expected them to point to different addresses", testA["b"]) - } - - // map copy - type Foo struct { - Alpha string - } - - type Bar struct { - Beta string - Gamma int - Delta *Foo - } - - type Biz struct { - Epsilon map[int]*Bar - } - - testB := Biz{ - Epsilon: map[int]*Bar{ - 0: &Bar{}, - 1: &Bar{ - Beta: "don't panic", - Gamma: 42, - Delta: nil, - }, - 2: &Bar{ - Beta: "sudo make me a sandwich.", - Gamma: 11, - Delta: &Foo{ - Alpha: "okay.", - }, - }, - }, - } - - copyB := Copy(testB).(Biz) - if !reflect.DeepEqual(testB, copyB) { - t.Errorf("got %#v; want %#v", copyB, testB) - return - } - - // check that the maps point to different locations - if unsafe.Pointer(&testB.Epsilon) == unsafe.Pointer(©B.Epsilon) { - t.Fatalf("expected the map pointers to be different; they weren't: testB: %v\tcopyB: %v", unsafe.Pointer(&testB.Epsilon), unsafe.Pointer(©B.Epsilon)) - } - - for k, v := range testB.Epsilon { - if v == nil && copyB.Epsilon[k] == nil { - continue - } - if v == nil && copyB.Epsilon[k] != nil { - t.Errorf("%d: expected copy of a nil entry to be nil; it wasn't: %#v", k, copyB.Epsilon[k]) - continue - } - if v == copyB.Epsilon[k] { - t.Errorf("entries for '%d' pointed to the same address: %v; expected them to point to different addresses", k, v) - continue - } - if v.Beta != copyB.Epsilon[k].Beta { - t.Errorf("%d.Beta: got %q; want %q", k, copyB.Epsilon[k].Beta, v.Beta) - } - if v.Gamma != copyB.Epsilon[k].Gamma { - t.Errorf("%d.Gamma: got %d; want %d", k, copyB.Epsilon[k].Gamma, v.Gamma) - } - if v.Delta == nil && copyB.Epsilon[k].Delta == nil { - continue - } - if v.Delta == nil && copyB.Epsilon[k].Delta != nil { - t.Errorf("%d.Delta: got %#v; want nil", k, copyB.Epsilon[k].Delta) - } - if v.Delta == copyB.Epsilon[k].Delta { - t.Errorf("%d.Delta: expected the pointers to be different, they were the same: %v", k, v.Delta) - continue - } - if v.Delta.Alpha != copyB.Epsilon[k].Delta.Alpha { - t.Errorf("%d.Delta.Foo: got %q; want %q", k, v.Delta.Alpha, copyB.Epsilon[k].Delta.Alpha) - } - } - - // test that map keys are deep copied - testC := map[*Foo][]string{ - &Foo{Alpha: "Henry Dorsett Case"}: []string{ - "Cutter", - }, - &Foo{Alpha: "Molly Millions"}: []string{ - "Rose Kolodny", - "Cat Mother", - "Steppin' Razor", - }, - } - - copyC := Copy(testC).(map[*Foo][]string) - if unsafe.Pointer(&testC) == unsafe.Pointer(©C) { - t.Fatalf("expected the map pointers to be different; they weren't: testB: %v\tcopyB: %v", unsafe.Pointer(&testB.Epsilon), unsafe.Pointer(©B.Epsilon)) - } - - // make sure the lengths are the same - if len(testC) != len(copyC) { - t.Fatalf("got len %d; want %d", len(copyC), len(testC)) - } - - // check that everything was deep copied: since the key is a pointer, we check to - // see if the pointers are different but the values being pointed to are the same. - for k, v := range testC { - for kk, vv := range copyC { - if *kk == *k { - if kk == k { - t.Errorf("key pointers should be different: orig: %p; copy: %p", k, kk) - } - // check that the slices are the same but different - if !reflect.DeepEqual(v, vv) { - t.Errorf("expected slice contents to be the same; they weren't: orig: %v; copy: %v", v, vv) - } - - if (*reflect.SliceHeader)(unsafe.Pointer(&v)).Data == (*reflect.SliceHeader)(unsafe.Pointer(&vv)).Data { - t.Errorf("expected the SliceHeaders.Data to point to different locations; they didn't: %v", (*reflect.SliceHeader)(unsafe.Pointer(&v)).Data) - } - break - } - } - } - - type Bizz struct { - *Foo - } - - testD := map[Bizz]string{ - Bizz{&Foo{"Neuromancer"}}: "Rio", - Bizz{&Foo{"Wintermute"}}: "Berne", - } - copyD := Copy(testD).(map[Bizz]string) - if len(copyD) != len(testD) { - t.Fatalf("copy had %d elements; expected %d", len(copyD), len(testD)) - } - - for k, v := range testD { - var found bool - for kk, vv := range copyD { - if reflect.DeepEqual(k, kk) { - found = true - // check that Foo points to different locations - if unsafe.Pointer(k.Foo) == unsafe.Pointer(kk.Foo) { - t.Errorf("Expected Foo to point to different locations; they didn't: orig: %p; copy %p", k.Foo, kk.Foo) - break - } - if *k.Foo != *kk.Foo { - t.Errorf("Expected copy of the key's Foo field to have the same value as the original, it wasn't: orig: %#v; copy: %#v", k.Foo, kk.Foo) - } - if v != vv { - t.Errorf("Expected the values to be the same; the weren't: got %v; want %v", vv, v) - } - } - } - if !found { - t.Errorf("expected key %v to exist in the copy; it didn't", k) - } - } -} - -type I struct { - A string -} - -func (i *I) DeepCopy() interface{} { - return &I{A: "custom copy"} -} - -type NestI struct { - I *I -} - -func TestInterface(t *testing.T) { - i := &I{A: "A"} - copied := Copy(i).(*I) - if copied.A != "custom copy" { - t.Errorf("expected value %v, but it's %v", "custom copy", copied.A) - } - // check for nesting values - ni := &NestI{I: &I{A: "A"}} - copiedNest := Copy(ni).(*NestI) - if copiedNest.I.A != "custom copy" { - t.Errorf("expected value %v, but it's %v", "custom copy", copiedNest.I.A) - } -} diff --git a/hashing/hasher_test.go b/hashing/hasher_test.go index b496589..e9532d6 100644 --- a/hashing/hasher_test.go +++ b/hashing/hasher_test.go @@ -7,14 +7,14 @@ import ( "os" "strconv" "testing" + + "github.com/stretchr/testify/assert" + ) func TestMurmurHashOnAlphanumericData(t *testing.T) { inFile, err := os.Open("../testdata/murmur3-sample-data-v2.csv") - if err != nil { - t.Error("Missing test file...") - return - } + assert.Nil(t, err) defer inFile.Close() reader := csv.NewReader(bufio.NewReader(inFile)) @@ -32,19 +32,13 @@ func TestMurmurHashOnAlphanumericData(t *testing.T) { digest, _ := strconv.ParseUint(arr[2], 10, 32) calculated := NewMurmur332Hasher(uint32(seed)).Hash([]byte(str)) - if calculated != uint32(digest) { - t.Errorf("%d: Murmur hash calculation failed for string %s. Should be %d and was %d", line, str, digest, calculated) - break - } + assert.Equal(t, calculated, uint32(digest)) } } func TestMurmurHashOnNonAlphanumericData(t *testing.T) { inFile, err := os.Open("../testdata/murmur3-sample-data-non-alpha-numeric-v2.csv") - if err != nil { - t.Error("Missing test file...") - return - } + assert.Nil(t, err) defer inFile.Close() reader := csv.NewReader(bufio.NewReader(inFile)) @@ -62,9 +56,6 @@ func TestMurmurHashOnNonAlphanumericData(t *testing.T) { digest, _ := strconv.ParseUint(arr[2], 10, 32) calculated := NewMurmur332Hasher(uint32(seed)).Hash([]byte(str)) - if calculated != uint32(digest) { - t.Errorf("%d: Murmur hash calculation failed for string %s. Should be %d and was %d", line, str, digest, calculated) - break - } + assert.Equal(t, calculated, uint32(digest)) } } diff --git a/hashing/murmur128_test.go b/hashing/murmur128_test.go index 0aacb79..30fa1ae 100644 --- a/hashing/murmur128_test.go +++ b/hashing/murmur128_test.go @@ -1,17 +1,17 @@ package hashing import ( - "os" + "os" "strconv" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func TestMurmur128(t *testing.T) { raw, err := os.ReadFile("../testdata/murmur3_64_uuids.csv") - if err != nil { - t.Error("error reading murmur128 test cases files: ", err.Error()) - } + assert.Nil(t, err) lines := strings.Split(string(raw), "\n") for _, line := range lines { @@ -23,8 +23,6 @@ func TestMurmur128(t *testing.T) { expected, _ := strconv.ParseInt(fields[2], 10, 64) h1, _ := Sum128WithSeed([]byte(fields[0]), uint32(seed)) - if int64(h1) != expected { - t.Errorf("Hashes don't match. Expected: %d, actual: %d", expected, uint64(h1)) - } + assert.Equal(t, expected, int64(h1)) } } diff --git a/hashing/util_test.go b/hashing/util_test.go index 19f5810..691adbb 100644 --- a/hashing/util_test.go +++ b/hashing/util_test.go @@ -1,22 +1,18 @@ package hashing -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestEncode(t *testing.T) { hash, err := Encode(nil, "something") - if hash != "" { - t.Error("Unexpected result") - } - if err == nil || err.Error() != "Hasher could not be nil" { - t.Error("Unexpected error message") - } - + assert.ErrorContains(t, err, "Hasher could not be nil") + assert.Equal(t, "", hash) + hasher := NewMurmur332Hasher(0) hash2, err := Encode(hasher, "something") - if err != nil { - t.Error("It should not return error") - } - if hash2 != "NDE0MTg0MjI2MQ==" { - t.Error("Unexpected result") - } + assert.Nil(t, err) + assert.Equal(t, "NDE0MTg0MjI2MQ==", hash2) } diff --git a/redis/helpers/helpers_test.go b/redis/helpers/helpers_test.go index d968aa0..f838af7 100644 --- a/redis/helpers/helpers_test.go +++ b/redis/helpers/helpers_test.go @@ -4,58 +4,38 @@ import ( "errors" "testing" - "github.com/splitio/go-toolkit/v6/redis" "github.com/splitio/go-toolkit/v6/redis/mocks" + "github.com/stretchr/testify/assert" ) func TestEnsureConnected(t *testing.T) { - redisClient := mocks.MockClient{ - PingCall: func() redis.Result { - return &mocks.MockResultOutput{ - ErrCall: func() error { return nil }, - StringCall: func() string { return "PONG" }, - } - }, - } - EnsureConnected(&redisClient) + var resMock mocks.MockResultOutput + resMock.On("String").Return(pong).Once() + resMock.On("Err").Return(nil).Once() + + var clientMock mocks.MockClient + clientMock.On("Ping").Return(&resMock).Once() + EnsureConnected(&clientMock) } func TestEnsureConnectedError(t *testing.T) { - defer func() { - if r := recover(); r != nil { - if r != "Couldn't connect to redis: someError" { - t.Error("Expected \"Couldn't connect to redis: someError\". Got: ", r) - } - } - }() - redisClient := mocks.MockClient{ - PingCall: func() redis.Result { - return &mocks.MockResultOutput{ - ErrCall: func() error { return errors.New("someError") }, - StringCall: func() string { return "" }, - } - }, - } - EnsureConnected(&redisClient) - t.Error("Should not reach this line") + var resMock mocks.MockResultOutput + resMock.On("String").Return("").Once() + resMock.On("Err").Return(errors.New("someError")).Once() + + var clientMock mocks.MockClient + clientMock.On("Ping").Return(&resMock).Once() + + assert.Panics(t, func() { EnsureConnected(&clientMock) }) } func TestEnsureConnectedNotPong(t *testing.T) { - defer func() { - if r := recover(); r != nil { - if r != "Invalid redis ping response when connecting: PANG" { - t.Error("Invalid redis ping response when connecting: PANG", r) - } - } - }() - redisClient := mocks.MockClient{ - PingCall: func() redis.Result { - return &mocks.MockResultOutput{ - ErrCall: func() error { return nil }, - StringCall: func() string { return "PANG" }, - } - }, - } - EnsureConnected(&redisClient) - t.Error("Should not reach this line") + var resMock mocks.MockResultOutput + resMock.On("String").Return("PANG").Once() + resMock.On("Err").Return(nil).Once() + + var clientMock mocks.MockClient + clientMock.On("Ping").Return(&resMock).Once() + + assert.Panics(t, func() { EnsureConnected(&clientMock) }) } diff --git a/redis/mocks/mocks.go b/redis/mocks/mocks.go index fb84349..e24aa1f 100644 --- a/redis/mocks/mocks.go +++ b/redis/mocks/mocks.go @@ -3,324 +3,299 @@ package mocks import ( "time" + "github.com/splitio/go-toolkit/v6/common" "github.com/splitio/go-toolkit/v6/redis" + "github.com/stretchr/testify/mock" ) -// MockResultOutput mocks struct -type MockResultOutput struct { - ErrCall func() error - IntCall func() int64 - StringCall func() string - BoolCall func() bool - DurationCall func() time.Duration - ResultCall func() (int64, error) - ResultStringCall func() (string, error) - MultiCall func() ([]string, error) - MultiInterfaceCall func() ([]interface{}, error) - MapStringStringCall func() (map[string]string, error) -} - -// Int mocks Int -func (m *MockResultOutput) Int() int64 { - return m.IntCall() +type MockClient struct { + mock.Mock } -// Err mocks Err -func (m *MockResultOutput) Err() error { - return m.ErrCall() +// ClusterCountKeysInSlot implements redis.Client. +func (m *MockClient) ClusterCountKeysInSlot(slot int) redis.Result { + return m.Called(slot).Get(0).(redis.Result) } -// String mocks String -func (m *MockResultOutput) String() string { - return m.StringCall() +// ClusterKeysInSlot implements redis.Client. +func (m *MockClient) ClusterKeysInSlot(slot int, count int) redis.Result { + return m.Called(slot, count).Get(0).(redis.Result) } -// Bool mocks Bool -func (m *MockResultOutput) Bool() bool { - return m.BoolCall() +// ClusterMode implements redis.Client. +func (m *MockClient) ClusterMode() bool { + return m.Called().Bool(0) } -// Duration mocks Duration -func (m *MockResultOutput) Duration() time.Duration { - return m.DurationCall() +// ClusterSlotForKey implements redis.Client. +func (m *MockClient) ClusterSlotForKey(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// Result mocks Result -func (m *MockResultOutput) Result() (int64, error) { - return m.ResultCall() +// Decr implements redis.Client. +func (m *MockClient) Decr(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// ResultString mocks ResultString -func (m *MockResultOutput) ResultString() (string, error) { - return m.ResultStringCall() +// Del implements redis.Client. +func (m *MockClient) Del(keys ...string) redis.Result { + return m.Called(common.AsInterfaceSlice(keys)...).Get(0).(redis.Result) } -// Multi mocks Multi -func (m *MockResultOutput) Multi() ([]string, error) { - return m.MultiCall() +// Eval implements redis.Client. +func (m *MockClient) Eval(script string, keys []string, args ...interface{}) redis.Result { + return m.Called(append([]interface{}{script, keys}, args...)...).Get(0).(redis.Result) } -// MultiInterface mocks MultiInterface -func (m *MockResultOutput) MultiInterface() ([]interface{}, error) { - return m.MultiInterfaceCall() +// Exists implements redis.Client. +func (m *MockClient) Exists(keys ...string) redis.Result { + return m.Called(common.AsInterfaceSlice(keys)...).Get(0).(redis.Result) } -// MapStringString mocks MapStringString -func (m *MockResultOutput) MapStringString() (map[string]string, error) { - return m.MapStringStringCall() +// Expire implements redis.Client. +func (m *MockClient) Expire(key string, value time.Duration) redis.Result { + return m.Called(key, value).Get(0).(redis.Result) } -// MpockPipeline impl -type MockPipeline struct { - LRangeCall func(key string, start, stop int64) - LTrimCall func(key string, start, stop int64) - LLenCall func(key string) - HIncrByCall func(key string, field string, value int64) - HLenCall func(key string) - SetCall func(key string, value interface{}, expiration time.Duration) - IncrCall func(key string) - DecrCall func(key string) - SAddCall func(key string, members ...interface{}) - SRemCall func(key string, members ...interface{}) - SMembersCall func(key string) - DelCall func(keys ...string) - ExecCall func() ([]redis.Result, error) +// Get implements redis.Client. +func (m *MockClient) Get(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) LRange(key string, start, stop int64) { - m.LRangeCall(key, start, stop) +// HGetAll implements redis.Client. +func (m *MockClient) HGetAll(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) LTrim(key string, start, stop int64) { - m.LTrimCall(key, start, stop) +// HIncrBy implements redis.Client. +func (m *MockClient) HIncrBy(key string, field string, value int64) redis.Result { + return m.Called(key, field, value).Get(0).(redis.Result) } -func (m *MockPipeline) LLen(key string) { - m.LLenCall(key) +// HSet implements redis.Client. +func (m *MockClient) HSet(key string, hashKey string, value interface{}) redis.Result { + return m.Called(key, hashKey, value).Get(0).(redis.Result) + } -func (m *MockPipeline) HIncrBy(key string, field string, value int64) { - m.HIncrByCall(key, field, value) +// Incr implements redis.Client. +func (m *MockClient) Incr(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) HLen(key string) { - m.HLenCall(key) +// Keys implements redis.Client. +func (m *MockClient) Keys(pattern string) redis.Result { + return m.Called(pattern).Get(0).(redis.Result) + } -func (m *MockPipeline) Set(key string, value interface{}, expiration time.Duration) { - m.SetCall(key, value, expiration) +// LLen implements redis.Client. +func (m *MockClient) LLen(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockPipeline) Incr(key string) { - m.IncrCall(key) +// LRange implements redis.Client. +func (m *MockClient) LRange(key string, start int64, stop int64) redis.Result { + return m.Called(key, start, stop).Get(0).(redis.Result) } -func (m *MockPipeline) Decr(key string) { - m.DecrCall(key) +// LTrim implements redis.Client. +func (m *MockClient) LTrim(key string, start int64, stop int64) redis.Result { + return m.Called(key, start, stop).Get(0).(redis.Result) + } -func (m *MockPipeline) SAdd(key string, members ...interface{}) { - m.SAddCall(key, members...) +// MGet implements redis.Client. +func (m *MockClient) MGet(keys []string) redis.Result { + return m.Called(keys).Get(0).(redis.Result) } -func (m *MockPipeline) SRem(key string, members ...interface{}) { - m.SRemCall(key, members...) +// Ping implements redis.Client. +func (m *MockClient) Ping() redis.Result { + return m.Called().Get(0).(redis.Result) } -func (m *MockPipeline) SMembers(key string) { - m.SMembersCall(key) +// Pipeline implements redis.Client. +func (m *MockClient) Pipeline() redis.Pipeline { + return m.Called().Get(0).(redis.Pipeline) } -func (m *MockPipeline) Del(keys ...string) { - m.DelCall(keys...) +// RPush implements redis.Client. +func (m *MockClient) RPush(key string, values ...interface{}) redis.Result { + return m.Called(append([]interface{}{key}, values...)...).Get(0).(redis.Result) + } -func (m *MockPipeline) Exec() ([]redis.Result, error) { - return m.ExecCall() +// SAdd implements redis.Client. +func (m *MockClient) SAdd(key string, members ...interface{}) redis.Result { + return m.Called(append([]interface{}{key}, members...)...).Get(0).(redis.Result) } -// MockClient mocks for testing purposes -type MockClient struct { - ClusterModeCall func() bool - ClusterCountKeysInSlotCall func(slot int) redis.Result - ClusterSlotForKeyCall func(key string) redis.Result - ClusterKeysInSlotCall func(slot int, count int) redis.Result - DelCall func(keys ...string) redis.Result - GetCall func(key string) redis.Result - SetCall func(key string, value interface{}, expiration time.Duration) redis.Result - PingCall func() redis.Result - ExistsCall func(keys ...string) redis.Result - KeysCall func(pattern string) redis.Result - SMembersCall func(key string) redis.Result - SIsMemberCall func(key string, member interface{}) redis.Result - SAddCall func(key string, members ...interface{}) redis.Result - SRemCall func(key string, members ...interface{}) redis.Result - IncrCall func(key string) redis.Result - DecrCall func(key string) redis.Result - RPushCall func(key string, values ...interface{}) redis.Result - LRangeCall func(key string, start, stop int64) redis.Result - LTrimCall func(key string, start, stop int64) redis.Result - LLenCall func(key string) redis.Result - ExpireCall func(key string, value time.Duration) redis.Result - TTLCall func(key string) redis.Result - MGetCall func(keys []string) redis.Result - SCardCall func(key string) redis.Result - EvalCall func(script string, keys []string, args ...interface{}) redis.Result - HIncrByCall func(key string, field string, value int64) redis.Result - HGetAllCall func(key string) redis.Result - HSetCall func(key string, hashKey string, value interface{}) redis.Result - TypeCall func(key string) redis.Result - PipelineCall func() redis.Pipeline - ScanCall func(cursor uint64, match string, count int64) redis.Result +// SCard implements redis.Client. +func (m *MockClient) SCard(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockClient) ClusterMode() bool { - return m.ClusterModeCall() +// SIsMember implements redis.Client. +func (m *MockClient) SIsMember(key string, member interface{}) redis.Result { + return m.Called(key, member).Get(0).(redis.Result) + } -func (m *MockClient) ClusterCountKeysInSlot(slot int) redis.Result { - return m.ClusterCountKeysInSlotCall(slot) +// SMembers implements redis.Client. +func (m *MockClient) SMembers(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -func (m *MockClient) ClusterSlotForKey(key string) redis.Result { - return m.ClusterSlotForKeyCall(key) +// SRem implements redis.Client. +func (m *MockClient) SRem(key string, members ...interface{}) redis.Result { + return m.Called(append([]interface{}{key}, members...)...).Get(0).(redis.Result) } -func (m *MockClient) ClusterKeysInSlot(slot int, count int) redis.Result { - return m.ClusterKeysInSlotCall(slot, count) +// Scan implements redis.Client. +func (m *MockClient) Scan(cursor uint64, match string, count int64) redis.Result { + return m.Called(cursor, match, count).Get(0).(redis.Result) } -// Del mocks get -func (m *MockClient) Del(keys ...string) redis.Result { - return m.DelCall(keys...) +// Set implements redis.Client. +func (m *MockClient) Set(key string, value interface{}, expiration time.Duration) redis.Result { + return m.Called(key, value, expiration).Get(0).(redis.Result) } -// Get mocks get -func (m *MockClient) Get(key string) redis.Result { - return m.GetCall(key) +// TTL implements redis.Client. +func (m *MockClient) TTL(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// Set mocks set -func (m *MockClient) Set(key string, value interface{}, expiration time.Duration) redis.Result { - return m.SetCall(key, value, expiration) +// Type implements redis.Client. +func (m *MockClient) Type(key string) redis.Result { + return m.Called(key).Get(0).(redis.Result) } -// Exists mocks set -func (m *MockClient) Exists(keys ...string) redis.Result { - return m.ExistsCall(keys...) +type MockPipeline struct { + mock.Mock } -// Ping mocks ping -func (m *MockClient) Ping() redis.Result { - return m.PingCall() +// Decr implements redis.Pipeline. +func (m *MockPipeline) Decr(key string) { + m.Called(key) } -// Keys mocks keys -func (m *MockClient) Keys(pattern string) redis.Result { - return m.KeysCall(pattern) +// Del implements redis.Pipeline. +func (m *MockPipeline) Del(keys ...string) { + m.Called(common.AsInterfaceSlice(keys)) } -// SMembers mocks SMembers -func (m *MockClient) SMembers(key string) redis.Result { - return m.SMembersCall(key) +// Exec implements redis.Pipeline. +func (m *MockPipeline) Exec() ([]redis.Result, error) { + args := m.Called() + return args.Get(0).([]redis.Result), args.Error(1) } -// SIsMember mocks SIsMember -func (m *MockClient) SIsMember(key string, member interface{}) redis.Result { - return m.SIsMemberCall(key, member) +// HIncrBy implements redis.Pipeline. +func (m *MockPipeline) HIncrBy(key string, field string, value int64) { + m.Called(key, field, value) } -// SAdd mocks SAdd -func (m *MockClient) SAdd(key string, members ...interface{}) redis.Result { - return m.SAddCall(key, members...) +// HLen implements redis.Pipeline. +func (m *MockPipeline) HLen(key string) { + m.Called(key) } -// SRem mocks SRem -func (m *MockClient) SRem(key string, members ...interface{}) redis.Result { - return m.SRemCall(key, members...) +// Incr implements redis.Pipeline. +func (m *MockPipeline) Incr(key string) { + m.Called(key) } -// Incr mocks Incr -func (m *MockClient) Incr(key string) redis.Result { - return m.IncrCall(key) +// LLen implements redis.Pipeline. +func (m *MockPipeline) LLen(key string) { + m.Called(key) } -// Decr mocks Decr -func (m *MockClient) Decr(key string) redis.Result { - return m.DecrCall(key) +// LRange implements redis.Pipeline. +func (m *MockPipeline) LRange(key string, start int64, stop int64) { + m.Called(key, start, stop) } -// RPush mocks RPush -func (m *MockClient) RPush(key string, values ...interface{}) redis.Result { - return m.RPushCall(key, values...) +// LTrim implements redis.Pipeline. +func (m *MockPipeline) LTrim(key string, start int64, stop int64) { + m.Called(key, start, stop) +} + +// SAdd implements redis.Pipeline. +func (m *MockPipeline) SAdd(key string, members ...interface{}) { + m.Called(append([]interface{}{key}, members...)...) } -// LRange mocks LRange -func (m *MockClient) LRange(key string, start, stop int64) redis.Result { - return m.LRangeCall(key, start, stop) +// SMembers implements redis.Pipeline. +func (m *MockPipeline) SMembers(key string) { + m.Called(key) } -// LTrim mocks LTrim -func (m *MockClient) LTrim(key string, start, stop int64) redis.Result { - return m.LTrimCall(key, start, stop) +// SRem implements redis.Pipeline. +func (m *MockPipeline) SRem(key string, members ...interface{}) { + m.Called(append([]interface{}{key}, members...)...) } -// LLen mocks LLen -func (m *MockClient) LLen(key string) redis.Result { - return m.LLenCall(key) +// Set implements redis.Pipeline. +func (m *MockPipeline) Set(key string, value interface{}, expiration time.Duration) { + m.Called(key, value) } -// Expire mocks Expire -func (m *MockClient) Expire(key string, value time.Duration) redis.Result { - return m.ExpireCall(key, value) +type MockResultOutput struct { + mock.Mock } -// TTL mocks TTL -func (m *MockClient) TTL(key string) redis.Result { - return m.TTLCall(key) +// Bool implements redis.Result. +func (m *MockResultOutput) Bool() bool { + return m.Called().Bool(0) } -// MGet mocks MGet -func (m *MockClient) MGet(keys []string) redis.Result { - return m.MGetCall(keys) +// Duration implements redis.Result. +func (m *MockResultOutput) Duration() time.Duration { + return m.Called().Get(0).(time.Duration) } -// SCard mocks SCard -func (m *MockClient) SCard(key string) redis.Result { - return m.SCardCall(key) +// Err implements redis.Result. +func (m *MockResultOutput) Err() error { + return m.Called().Error(0) } -// Eval mocks Eval -func (m *MockClient) Eval(script string, keys []string, args ...interface{}) redis.Result { - return m.EvalCall(script, keys, args...) +// Int implements redis.Result. +func (m *MockResultOutput) Int() int64 { + return m.Called().Get(0).(int64) } -// HIncrBy mocks HIncrByCall -func (m *MockClient) HIncrBy(key string, field string, value int64) redis.Result { - return m.HIncrByCall(key, field, value) +func (m *MockResultOutput) String() string { + return m.Called().Get(0).(string) } -// HGetAll mocks HGetAll -func (m *MockClient) HGetAll(key string) redis.Result { - return m.HGetAllCall(key) +// MapStringString implements redis.Result. +func (m *MockResultOutput) MapStringString() (map[string]string, error) { + args := m.Called() + return args.Get(0).(map[string]string), args.Error(1) } -// HSet implements HGetAll wrapper for redis -func (m *MockClient) HSet(key string, hashKey string, value interface{}) redis.Result { - return m.HSetCall(key, hashKey, value) +// Multi implements redis.Result. +func (m *MockResultOutput) Multi() ([]string, error) { + args := m.Called() + return args.Get(0).([]string), args.Error(1) } -// Type implements Type wrapper for redis with prefix -func (m *MockClient) Type(key string) redis.Result { - return m.TypeCall(key) +// MultiInterface implements redis.Result. +func (m *MockResultOutput) MultiInterface() ([]interface{}, error) { + args := m.Called() + return args.Get(0).([]interface{}), args.Error(1) } -// Pipeline mock -func (m *MockClient) Pipeline() redis.Pipeline { - return m.PipelineCall() +// Result implements redis.Result. +func (m *MockResultOutput) Result() (int64, error) { + args := m.Called() + return args.Get(0).(int64), args.Error(1) } -// Scan mock -func (m *MockClient) Scan(cursor uint64, match string, count int64) redis.Result { - return m.ScanCall(cursor, match, count) +// ResultString implements redis.Result. +func (m *MockResultOutput) ResultString() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) } diff --git a/redis/wrapper_test.go b/redis/wrapper_test.go index e5fb32b..cd7b865 100644 --- a/redis/wrapper_test.go +++ b/redis/wrapper_test.go @@ -18,29 +18,20 @@ func TestRedisWrapperKeysAndScan(t *testing.T) { } keys, err := client.Keys("utest*").Multi() - if err != nil { - t.Error("there should not be any error. Got: ", err) - } - - if len(keys) != 10 { - t.Error("should be 10 keys. Got: ", len(keys)) - } - + assert.Nil(t, err) + assert.Equal(t, 10, len(keys)) var cursor uint64 + scanKeys := make([]string, 0) for { result := client.Scan(cursor, "utest*", 10) - if result.Err() != nil { - t.Error("there should not be any error. Got: ", result.Err()) - } + assert.Nil(t, result.Err()) cursor = uint64(result.Int()) keys, err := result.Multi() - if err != nil { - t.Error("there should not be any error. Got: ", err) - } - + assert.Nil(t, err) + scanKeys = append(scanKeys, keys...) if cursor == 0 { @@ -48,10 +39,7 @@ func TestRedisWrapperKeysAndScan(t *testing.T) { } } - if len(scanKeys) != 10 { - t.Error("should be 10 keys. Got: ", len(scanKeys)) - } - + assert.Equal(t, 10, len(scanKeys)) for i := 0; i < 10; i++ { client.Del(fmt.Sprintf("utest.key-del%d", i)) } @@ -84,64 +72,24 @@ func TestRedisWrapperPipeline(t *testing.T) { pipe.Decr("key-incr") pipe.Del([]string{"key-del1", "key-del2"}...) result, err := pipe.Exec() - if err != nil { - t.Error("there should not be any error. Got: ", err) - } - - if len(result) != 14 { - t.Error("there should be 13 elements") - } + assert.Nil(t, err) + assert.Equal(t, 14, len(result)) items, _ := result[0].Multi() assert.Equal(t, []string{"e1", "e2", "e3"}, items) - if l := result[1].Int(); l != 3 { - t.Error("length should be 3. is: ", l) - } - - if i := client.LLen("key1").Int(); i != 1 { - t.Error("new length should be 1. Is: ", i) - } - - if c := result[3].Int(); c != 5 { - t.Error("count should be 5. Is: ", c) - } - - if c := result[4].Int(); c != 4 { - t.Error("count should be 5. Is: ", c) - } - - if c := result[5].Int(); c != 7 { - t.Error("count should be 5. Is: ", c) - } - - if l := result[6].Int(); l != 2 { - t.Error("hlen should be 2. is: ", l) - } - - if ib := client.HIncrBy("key-test", "field-test", 1); ib.Int() != 6 { - t.Error("new count should be 6") - } - - if ib := client.Get("key-set"); ib.String() != "field-test-1" { - t.Error("it should be field-test-1") - } - - if c := result[8].Int(); c != 2 { - t.Error("count should be 2. Is: ", c) - } - if d, _ := result[9].Multi(); len(d) != 2 { - t.Error("count should be 2. Is: ", len(d)) - } - if c := result[10].Int(); c != 2 { - t.Error("count should be 2. Is: ", c) - } - if c := result[11].Int(); c != 1 { - t.Error("count should be 1. Is: ", c) - } - if c := result[12].Int(); c != 0 { - t.Error("count should be zero. Is: ", c) - } - if c := result[13].Int(); c != 2 { - t.Error("count should be 2. Is: ", c) - } + assert.Equal(t, int64(3), result[1].Int()) + assert.Equal(t, int64(1), client.LLen("key1").Int()) + assert.Equal(t, int64(5), result[3].Int()) + assert.Equal(t, int64(4), result[4].Int()) + assert.Equal(t, int64(7), result[5].Int()) + assert.Equal(t, int64(2), result[6].Int()) + assert.Equal(t, int64(6), client.HIncrBy("key-test", "field-test", 1).Int()) + assert.Equal(t, "field-test-1", client.Get("key-set").String()) + assert.Equal(t, int64(2), result[8].Int()) + d, _ := result[9].Multi() + assert.Equal(t, 2, len(d)) + assert.Equal(t, int64(2), result[10].Int()) + assert.Equal(t, int64(1), result[11].Int()) + assert.Equal(t, int64(0), result[12].Int()) + assert.Equal(t, int64(2), result[13].Int()) } diff --git a/sse/event_test.go b/sse/event_test.go index a64e257..d48806a 100644 --- a/sse/event_test.go +++ b/sse/event_test.go @@ -2,6 +2,9 @@ package sse import ( "testing" + + "github.com/stretchr/testify/assert" + ) func TestEventBuilder(t *testing.T) { @@ -13,30 +16,16 @@ func TestEventBuilder(t *testing.T) { builder.AddLine(":some Comment") e := builder.Build() - if e.Event() != "message" { - t.Error("event should be 'message'") - } - if e.Data() != "something" { - t.Error("data should be 'something'") - } - if e.ID() != "1234" { - t.Error("Id should be 1234") - } - if e.Retry() != 1 { - t.Error("retry should be 1234") - } - if e.IsEmpty() { - t.Error("event should not be empty") - } - if e.IsError() { - t.Error("event is not an error") - } + assert.Equal(t, "message", e.Event()) + assert.Equal(t, "something", e.Data()) + assert.Equal(t, "1234", e.ID()) + assert.Equal(t, int64(1), e.Retry()) + assert.False(t, e.IsEmpty()) + assert.False(t, e.IsEmpty()) builder.Reset() builder.AddLine("event: error") builder.AddLine("data: someError") e2 := builder.Build() - if !e2.IsError() { - t.Error("event is an error") - } + assert.True(t, e2.IsError()) } diff --git a/sse/mocks/mocks.go b/sse/mocks/mocks.go index b90ec8c..235f108 100644 --- a/sse/mocks/mocks.go +++ b/sse/mocks/mocks.go @@ -1,34 +1,31 @@ package mocks +import "github.com/stretchr/testify/mock" + type RawEventMock struct { - IDCall func() string - EventCall func() string - DataCall func() string - RetryCall func() int64 - IsErrorCall func() bool - IsEmptyCall func() bool + mock.Mock } func (r *RawEventMock) ID() string { - return r.IDCall() + return r.Called().String(0) } func (r *RawEventMock) Event() string { - return r.EventCall() + return r.Called().String(0) } func (r *RawEventMock) Data() string { - return r.DataCall() + return r.Called().String(0) } func (r *RawEventMock) Retry() int64 { - return r.RetryCall() + return r.Called().Get(0).(int64) } func (r *RawEventMock) IsError() bool { - return r.IsErrorCall() + return r.Called().Bool(0) } func (r *RawEventMock) IsEmpty() bool { - return r.IsEmptyCall() + return r.Called().Bool(0) } diff --git a/sse/sse_test.go b/sse/sse_test.go index 59e256b..a6a2ee2 100644 --- a/sse/sse_test.go +++ b/sse/sse_test.go @@ -1,7 +1,6 @@ package sse import ( - "errors" "fmt" "net/http" "net/http/httptest" @@ -10,16 +9,15 @@ import ( "time" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) func TestSSEErrorConnecting(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) client, _ := NewClient("", 120, 10, logger) err := client.Do(make(map[string]string), make(map[string]string), func(e RawEvent) { t.Error("It should not execute anything") }) - asErrConecting := &ErrConnectionFailed{} - if !errors.As(err, &asErrConecting) { - t.Errorf("Unexpected type of error: %+v", err) - } + _, ok := err.(*ErrConnectionFailed) + assert.True(t, ok) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -34,25 +32,19 @@ func TestSSEErrorConnecting(t *testing.T) { mockedClient.lifecycle.Setup() err = mockedClient.Do(make(map[string]string), make(map[string]string), func(e RawEvent) { - t.Error("Should not execute callback") + assert.Fail(t, "Should not execute callback") }) - if !errors.As(err, &asErrConecting) { - t.Errorf("Unexpected type of error: %+v", err) - } + _, ok = err.(*ErrConnectionFailed) + assert.True(t, ok) } func TestSSE(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("some") != "some" { - t.Error("It should send header") - } - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + assert.Equal(t, "some", r.Header.Get("some")) + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -79,18 +71,14 @@ func TestSSE(t *testing.T) { result = e mutextTest.Unlock() }) - if err != nil { - t.Error("sse client ended in error:", err) - } + assert.Nil(t, err) }() time.Sleep(2 * time.Second) mockedClient.Shutdown(true) mutextTest.RLock() - if result.Data() != `{"id":"YCh53QfLxO:0:0","data":"some","timestamp":1591911770828}` { - t.Error("Unexpected result: ", result.Data()) - } + assert.Equal(t, `{"id":"YCh53QfLxO:0:0","data":"some","timestamp":1591911770828}`, result.Data()) mutextTest.RUnlock() } @@ -103,11 +91,8 @@ func TestSSENoTimeout(t *testing.T) { finished := false mutexTest.Unlock() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -129,15 +114,11 @@ func TestSSENoTimeout(t *testing.T) { time.Sleep(1500 * time.Millisecond) mutexTest.RLock() - if finished { - t.Error("It should not be finished") - } + assert.False(t, finished) mutexTest.RUnlock() time.Sleep(1500 * time.Millisecond) mutexTest.RLock() - if !finished { - t.Error("It should be finished") - } + assert.True(t, finished) mutexTest.RUnlock() clientSSE.Shutdown(true) } @@ -146,11 +127,8 @@ func TestStopBlock(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -172,9 +150,7 @@ func TestStopBlock(t *testing.T) { waiter := make(chan struct{}, 1) go func() { err := mockedClient.Do(make(map[string]string), make(map[string]string), func(e RawEvent) {}) - if err != nil { - t.Error("sse client ended in error: ", err) - } + assert.Nil(t, err) waiter <- struct{}{} }() @@ -187,11 +163,8 @@ func TestConnectionEOF(t *testing.T) { logger := logging.NewLogger(&logging.LoggerOptions{}) var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, err := w.(http.Flusher) - if !err { - t.Error("Unexpected error") - return - } + flusher, ok := w.(http.Flusher) + assert.True(t, ok) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -211,60 +184,6 @@ func TestConnectionEOF(t *testing.T) { mockedClient.lifecycle.Setup() err := mockedClient.Do(make(map[string]string), make(map[string]string), func(e RawEvent) {}) - if err != ErrReadingStream { - t.Error("Should have triggered an ErrorReadingStreamError. Got: ", err) - } - + assert.ErrorIs(t, err, ErrReadingStream) mockedClient.Shutdown(true) } - -/* -func TestCustom(t *testing.T) { - url := `https://streaming.split.io/event-stream` - logger := logging.NewLogger(&logging.LoggerOptions{LogLevel: logging.LevelError, StandardLoggerFlags: log.Llongfile}) - client, _ := NewClient(url, 50, logger) - - ready := make(chan struct{}) - accessToken := `` - channels := "NzM2MDI5Mzc0_MTgyNTg1MTgwNg==_splits,[?occupancy=metrics.publishers]control_pri,[?occupancy=metrics.publishers]control_sec" - go func() { - err := client.Do( - map[string]string{ - "accessToken": accessToken, - "v": "1.1", - "channel": channels, - }, - func(e RawEvent) { - fmt.Printf("Event: %+v\n", e) - }) - if err != nil { - t.Error("sse error:", err) - } - ready <- struct{}{} - }() - time.Sleep(5 * time.Second) - client.Shutdown(true) - <-ready - fmt.Println(1) - go func() { - err := client.Do( - map[string]string{ - "accessToken": accessToken, - "v": "1.1", - "channel": channels, - }, - func(e RawEvent) { - fmt.Printf("Event: %+v\n", e) - }) - if err != nil { - t.Error("sse error:", err) - } - ready <- struct{}{} - }() - time.Sleep(5 * time.Second) - client.Shutdown(true) - <-ready - fmt.Println(2) - -} -*/ diff --git a/struct/jsonvalidator/validator_test.go b/struct/jsonvalidator/validator_test.go index c79eb74..3dea230 100644 --- a/struct/jsonvalidator/validator_test.go +++ b/struct/jsonvalidator/validator_test.go @@ -2,6 +2,8 @@ package jsonvalidator import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestLen0(t *testing.T) { @@ -13,12 +15,8 @@ func TestLen0(t *testing.T) { originChild := OriginChild{two: "Test", three: 1} err := ValidateConfiguration(originChild, nil) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "no configuration provided" { - t.Error("Wrong message") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "no configuration provided") } func TestSame(t *testing.T) { @@ -36,9 +34,7 @@ func TestSame(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10}}")) - if err != nil { - t.Error("Should not inform error") - } + assert.Nil(t, err) } func TestDifferentPropertyParent(t *testing.T) { @@ -56,12 +52,8 @@ func TestDifferentPropertyParent(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"four\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"four\" is not a valid property in configuration" { - t.Error("Wrong message") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"four\" is not a valid property in configuration") } func TestDifferentPropertyChild(t *testing.T) { @@ -79,12 +71,8 @@ func TestDifferentPropertyChild(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"four\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message", err.Error()) - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestDifferentParentAndChild(t *testing.T) { @@ -102,12 +90,8 @@ func TestDifferentParentAndChild(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"testChild\": {\"two\": \"test\", \"three\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"testChild\" is not a valid property in configuration" { - t.Error("Wrong message, it should inform parent") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"testChild\" is not a valid property in configuration") } func TestDifferentPropertyInChild(t *testing.T) { @@ -125,12 +109,8 @@ func TestDifferentPropertyInChild(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10, \"four\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message=", err.Error()) - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestDifferentPropertyInChildBool(t *testing.T) { @@ -148,12 +128,8 @@ func TestDifferentPropertyInChildBool(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10, \"four\": true}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message=") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestDifferentPropertyInChildNumber(t *testing.T) { @@ -171,12 +147,8 @@ func TestDifferentPropertyInChildNumber(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"two\": \"test\", \"three\": 10, \"four\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.four\" is not a valid property in configuration" { - t.Error("Wrong message=") - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.four\" is not a valid property in configuration") } func TestSameThirdLevel(t *testing.T) { @@ -200,11 +172,7 @@ func TestSameThirdLevel(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"child\": {\"two\": \"test\", \"three\": 10}, \"three\": 10}}")) - if err != nil { - t.Error(err.Error()) - - t.Error("Should not inform error") - } + assert.Nil(t, err) } func TestDifferenthirdLevel(t *testing.T) { @@ -228,10 +196,6 @@ func TestDifferenthirdLevel(t *testing.T) { origin := Origin{OriginChild: originChild, One: 1} err := ValidateConfiguration(origin, []byte("{\"one\": 10, \"originChild\": {\"child\": {\"t\": \"test\", \"three\": 10}, \"three\": 10}}")) - if err == nil { - t.Error("Should inform error") - } - if err.Error() != "\"originChild.child.t\" is not a valid property in configuration" { - t.Error("Wrong message", err.Error()) - } + assert.NotNil(t, err) + assert.ErrorContains(t, err, "\"originChild.child.t\" is not a valid property in configuration") } diff --git a/struct/traits/lifecycle/lifecycle_test.go b/struct/traits/lifecycle/lifecycle_test.go index e0ee14f..52e7c6e 100644 --- a/struct/traits/lifecycle/lifecycle_test.go +++ b/struct/traits/lifecycle/lifecycle_test.go @@ -4,31 +4,19 @@ import ( "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestLifecycleManager(t *testing.T) { m := Manager{} m.Setup() - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } - - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } - - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } - - if !m.InitializationComplete() { - t.Error("should complete initialization correctly") - } - - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done := make(chan struct{}, 1) go func() { @@ -43,36 +31,22 @@ func TestLifecycleManager(t *testing.T) { } }() - if !m.BeginShutdown() { - t.Error("shutdown should be correctly propagated") - } - if m.BeginShutdown() { - t.Error("once shutdown is started, it should no longer propagate further requests") - } - m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } - <-done // ensure that await actually waits + assert.True(t, m.BeginShutdown()) + assert.False(t, m.BeginShutdown()) - // Start again + m.AwaitShutdownComplete() - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } + assert.False(t, m.IsRunning()) - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } + <-done // ensure that await actually waits - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } + // Start again - m.InitializationComplete() - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done = make(chan struct{}, 1) go func() { @@ -87,16 +61,13 @@ func TestLifecycleManager(t *testing.T) { } }() - if !m.BeginShutdown() { - t.Error("shutdown should be correctly propagated") - } - if m.BeginShutdown() { - t.Error("once shutdown is started, it should no longer propagate further requests") - } + assert.True(t, m.BeginShutdown()) + assert.False(t, m.BeginShutdown()) + m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + + assert.False(t, m.IsRunning()) + <-done // ensure that await actually waits } @@ -104,22 +75,11 @@ func TestLifecycleManagerAbnormalShutdown(t *testing.T) { m := Manager{} m.Setup() - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } - - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } - - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } - - m.InitializationComplete() - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done := make(chan struct{}, 1) go func() { @@ -134,30 +94,18 @@ func TestLifecycleManagerAbnormalShutdown(t *testing.T) { } }() + m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + assert.False(t, m.IsRunning()) <-done // ensure that await actually waits // Start again - if !m.BeginInitialization() { - t.Error("initialization should begin properly.") - } - - if m.IsRunning() { - t.Error("isRunning should be false while initialization is going on") - } - - if m.BeginInitialization() { - t.Error("initialization should fail if called more than once.") - } - - m.InitializationComplete() - if !m.IsRunning() { - t.Error("it should be running") - } + assert.True(t, m.BeginInitialization()) + assert.False(t, m.IsRunning()) + assert.False(t, m.BeginInitialization()) + assert.True(t, m.InitializationComplete()) + assert.True(t, m.IsRunning()) done = make(chan struct{}, 1) go func() { @@ -172,17 +120,12 @@ func TestLifecycleManagerAbnormalShutdown(t *testing.T) { } }() - if !m.BeginShutdown() { - t.Error("shutdown should be correctly propagated") - } + assert.True(t, m.BeginShutdown()) + assert.False(t, m.BeginShutdown()) - if m.BeginShutdown() { - t.Error("once shutdown is started, it should no longer propagate further requests") - } m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + assert.False(t, m.IsRunning()) + <-done // ensure that await actually waits } @@ -190,15 +133,9 @@ func TestShutdownRequestWhileInitNotComplete(t *testing.T) { m := Manager{} m.Setup() - m.BeginInitialization() - if !m.BeginShutdown() { - t.Error("should accept the shutdown request") - } - - if m.InitializationComplete() { - t.Error("initialization cannot complete.") - } - + assert.True(t, m.BeginInitialization()) + assert.True(t, m.BeginShutdown()) + assert.False(t, m.InitializationComplete()) m.ShutdownComplete() // Now restart the lifecycle to see if it works properly @@ -221,14 +158,11 @@ func TestShutdownRequestWhileInitNotComplete(t *testing.T) { } } }() - m.BeginShutdown() + + assert.True(t, m.BeginShutdown()) m.AwaitShutdownComplete() - if m.IsRunning() { - t.Error("should not be running") - } + assert.False(t, m.IsRunning()) <-done // ensure that await actually waits - if atomic.LoadInt32(&executed) != 0 { - t.Error("the goroutine should have not executed further than the InitializationComplete check.") - } + assert.Equal(t, int32(0), atomic.LoadInt32(&executed)) } diff --git a/sync/atomicbool_test.go b/sync/atomicbool_test.go index 445ee8e..64d53c8 100644 --- a/sync/atomicbool_test.go +++ b/sync/atomicbool_test.go @@ -2,40 +2,20 @@ package sync import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestAtomicBool(t *testing.T) { a := NewAtomicBool(false) - if a.IsSet() { - t.Error("initial value should be false") - } - - if !a.TestAndSet() { - t.Error("compare and swap should succeed with no other concurrent access.") - } - - if a.TestAndSet() { - t.Error("compare and swap should return false if it didn't change anything.") - } - - if !a.IsSet() { - t.Error("should now be true") - } + assert.False(t, a.IsSet()) + assert.True(t, a.TestAndSet()) + assert.False(t, a.TestAndSet()) + assert.True(t, a.IsSet()) b := NewAtomicBool(true) - if !b.IsSet() { - t.Error("initial value should be true") - } - - if b.TestAndClear() != true { - t.Error("compare and swap should succeed with no other concurrent access.") - } - - if !a.TestAndClear() { - t.Error("compare and swap should return false if it didn't change anything.") - } - - if b.IsSet() { - t.Error("should now be false") - } + assert.True(t, b.IsSet()) + assert.True(t, b.TestAndClear()) + assert.False(t, b.TestAndClear()) + assert.False(t, b.IsSet()) } diff --git a/workerpool/workerpool_test.go b/workerpool/workerpool_test.go index 9606b7f..5ea67ad 100644 --- a/workerpool/workerpool_test.go +++ b/workerpool/workerpool_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/splitio/go-toolkit/v6/logging" + "github.com/stretchr/testify/assert" ) var resMutex sync.RWMutex @@ -55,23 +56,17 @@ func TestWorkerAdminConstructionAndNormalOperation(t *testing.T) { } resMutex.RLock() - if results["worker_2"] > 10 { - t.Error("Worker should have stopped working!") - } + assert.Less(t, results["worker_2"], 10) resMutex.RUnlock() + time.Sleep(time.Second * 1) errs := wa.StopAll(false) - if errs != nil { - t.Error("Not all workers stopped properly") - t.Error(errs) - } + assert.Nil(t, errs) time.Sleep(time.Second * 1) for _, i := range []int{1, 2, 3} { wName := fmt.Sprintf("worker_%d", i) - if wa.IsWorkerRunning(wName) { - t.Errorf("Worker %s should be stopped", wName) - } + assert.False(t, wa.IsWorkerRunning(wName)) } } @@ -131,21 +126,15 @@ func TestWaitingForWorkersToFinish(t *testing.T) { } resMutex.RLock() - if results["worker_2"] > 10 { - t.Error("Worker should have stopped working!") - } + assert.Less(t, results["worker_2"], 10) resMutex.RUnlock() time.Sleep(time.Second * 1) + errs := wa.StopAll(true) - if errs != nil { - t.Error("Not all workers stopped properly") - t.Error(errs) - } + assert.Nil(t, errs) for _, i := range []int{1, 2, 3, 4} { wName := fmt.Sprintf("worker_%d", i) - if wa.IsWorkerRunning(wName) { - t.Errorf("Worker %s should be stopped", wName) - } + assert.False(t, wa.IsWorkerRunning(wName)) } } From 714ee30c71ea18ca828f3b5283f6da31aa491fd3 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Thu, 28 Mar 2024 16:42:43 -0300 Subject: [PATCH 4/8] update ci --- .github/workflows/ci.yml | 4 ++-- Makefile | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b3e82df..ac9959a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,13 +25,13 @@ jobs: - name: Set up Go version uses: actions/setup-go@v3 with: - go-version: '1.18.0' + go-version: '1.21.0' - name: Go mod run: go mod tidy - name: Execute tests - run: go test -coverprofile=coverage.out -count=1 -race ./... + run: make test - name: SonarQube Scan (Push) if: github.event_name == 'push' diff --git a/Makefile b/Makefile index 2b35782..b706847 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,11 @@ GO ?= go +COVERAGE_OUT = coverage.out .PHONY: test test-norace test: - $(GO) test ./... -count=1 -race + $(GO) test ./... -count=1 -race -coverprofile=$(COVERAGE_OUT) test-norace: $(GO) test ./... -count=1 - From 7499252b9a865d69bbcff9e96030006d492b6ca1 Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Fri, 5 Apr 2024 18:45:53 -0300 Subject: [PATCH 5/8] remove comparable constraint from cache value --- datastructures/cache/cache.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datastructures/cache/cache.go b/datastructures/cache/cache.go index c1ed7dc..453cd4e 100644 --- a/datastructures/cache/cache.go +++ b/datastructures/cache/cache.go @@ -12,13 +12,13 @@ const ( ) // SimpleLRU is an in-memory TTL & LRU cache -type SimpleLRU[K comparable, V comparable] interface { +type SimpleLRU[K comparable, V any] interface { Get(key K) (V, error) Set(key K, value V) error } // SimpleLRUImpl implements the Simple interface -type SimpleLRUImpl[K comparable, V comparable] struct { +type SimpleLRUImpl[K comparable, V any] struct { ttl time.Duration maxLen int ttls map[K]time.Time @@ -27,7 +27,7 @@ type SimpleLRUImpl[K comparable, V comparable] struct { mutex sync.Mutex } -type centry[K comparable, V comparable] struct { +type centry[K comparable, V any] struct { key K value V } @@ -99,7 +99,7 @@ func (c *SimpleLRUImpl[K, V]) Set(key K, value V) error { } // NewSimple returns a new Simple instance of the specified size and TTL -func NewSimpleLRU[K comparable, V comparable](maxSize int, ttl time.Duration) (*SimpleLRUImpl[K, V], error) { +func NewSimpleLRU[K comparable, V any](maxSize int, ttl time.Duration) (*SimpleLRUImpl[K, V], error) { if maxSize <= 0 { return nil, fmt.Errorf("Cache size should be > 0. Is: %d", maxSize) } From 544d3ac961c0b53330392b270d3f0e21a36ce4ed Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Thu, 25 Jul 2024 15:07:49 -0300 Subject: [PATCH 6/8] add deref --- common/common.go | 8 ++++++++ common/common_test.go | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/common/common.go b/common/common.go index de86446..3663b3e 100644 --- a/common/common.go +++ b/common/common.go @@ -34,6 +34,14 @@ func PointerOf[T any](x interface{}) *T { return &ta } +// DerefOr returns value pointed by `tp` or the fallback supplied +func DerefOr[T any](tp *T, fallback T) T { + if tp == nil { + return fallback + } + return *tp +} + // PartitionSliceByLength partitions a slice into multiple slices of up to `maxItems` size func PartitionSliceByLength[T comparable](items []T, maxItems int) [][]T { var splitted [][]T diff --git a/common/common_test.go b/common/common_test.go index a010f41..8e82c9e 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -62,6 +62,14 @@ func TestDedupeInNewSlice(t *testing.T) { assert.ElementsMatch(t, []string{"a", "c"}, UnorderedDedupedCopy([]string{"c", "c", "a"})) } +func TestDerefOr(t *testing.T) { + assert.Equal(t, "hola", DerefOr(Ref("hola"), "sarasa")) + assert.Equal(t, "sarasa", DerefOr(nil, "sarasa")) + assert.Equal(t, "", DerefOr(nil, "")) + assert.Equal(t, 1, Ref(1), 2) + assert.Equal(t, 2, nil, 2) +} + func TestValueOr(t *testing.T) { assert.Equal(t, int64(3), ValueOr[int64](0, 3)) assert.Equal(t, (*int)(nil), ValueOr[*int](nil, nil)) From f4ba753312e6cdefd295d63e5c224bd040243c1b Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Thu, 25 Jul 2024 15:13:32 -0300 Subject: [PATCH 7/8] fix test --- common/common_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common_test.go b/common/common_test.go index 8e82c9e..e7eb44e 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -66,8 +66,8 @@ func TestDerefOr(t *testing.T) { assert.Equal(t, "hola", DerefOr(Ref("hola"), "sarasa")) assert.Equal(t, "sarasa", DerefOr(nil, "sarasa")) assert.Equal(t, "", DerefOr(nil, "")) - assert.Equal(t, 1, Ref(1), 2) - assert.Equal(t, 2, nil, 2) + assert.Equal(t, 1, DerefOr(Ref(1), 2)) + assert.Equal(t, 2, DerefOr(nil, 2)) } func TestValueOr(t *testing.T) { From d2571ad4940283a934832b7d58dcc7166dc4aa2a Mon Sep 17 00:00:00 2001 From: Martin Redolatti Date: Thu, 22 Aug 2024 11:47:36 -0300 Subject: [PATCH 8/8] allow manually evicting a key --- datastructures/cache/cache.go | 26 +++++++++++++++++++++----- datastructures/cache/cache_test.go | 29 +++++++++++++++++++++-------- datastructures/cache/multilevel.go | 2 +- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/datastructures/cache/cache.go b/datastructures/cache/cache.go index 453cd4e..68cfb66 100644 --- a/datastructures/cache/cache.go +++ b/datastructures/cache/cache.go @@ -15,6 +15,7 @@ const ( type SimpleLRU[K comparable, V any] interface { Get(key K) (V, error) Set(key K, value V) error + FlushKey(key K) } // SimpleLRUImpl implements the Simple interface @@ -98,16 +99,31 @@ func (c *SimpleLRUImpl[K, V]) Set(key K, value V) error { return nil } +// FlushKey removes an entry form the cache if it exists. Does nothing otherwise +func (c *SimpleLRUImpl[K, V]) FlushKey(key K) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if node, ok := c.items[key]; ok { + c.lru.MoveToBack(node) + delete(c.items, key) + if c.ttls != nil { + delete(c.ttls, key) + } + c.lru.Remove(c.lru.Back()) + } +} + // NewSimple returns a new Simple instance of the specified size and TTL func NewSimpleLRU[K comparable, V any](maxSize int, ttl time.Duration) (*SimpleLRUImpl[K, V], error) { if maxSize <= 0 { return nil, fmt.Errorf("Cache size should be > 0. Is: %d", maxSize) } - - var ttls map[K]time.Time = nil - if ttl != NoTTL { - ttls = make(map[K]time.Time) - } + + var ttls map[K]time.Time = nil + if ttl != NoTTL { + ttls = make(map[K]time.Time) + } return &SimpleLRUImpl[K, V]{ maxLen: maxSize, diff --git a/datastructures/cache/cache_test.go b/datastructures/cache/cache_test.go index 1bf9eda..ef9c247 100644 --- a/datastructures/cache/cache_test.go +++ b/datastructures/cache/cache_test.go @@ -63,6 +63,19 @@ func TestSimpleCache(t *testing.T) { assert.Equal(t, asExpired.When, ttl.Add(cache.ttl)) } + + assert.Nil(t, cache.Set("lala", 123)) + v, err := cache.Get("lala") + assert.Nil(t, err) + assert.Equal(t, 123, v) + + cache.FlushKey("lala") + v, err = cache.Get("lala") + assert.NotNil(t, err) + assert.Equal(t, 0, v) + + var exp *Miss + assert.ErrorAs(t, err, &exp) } func TestSimpleCacheHighConcurrency(t *testing.T) { @@ -103,25 +116,25 @@ func TestInt64Cache(t *testing.T) { for i := int64(1); i <= 5; i++ { val, err := c.Get(i) - assert.Nil(t, err) - assert.Equal(t, i, val) + assert.Nil(t, err) + assert.Equal(t, i, val) } c.Set(6, 6) // Oldest item (1) should have been removed val, err := c.Get(1) - assert.NotNil(t, err) + assert.NotNil(t, err) _, ok := err.(*Miss) - assert.True(t, ok) - assert.Equal(t, int64(0), val) + assert.True(t, ok) + assert.Equal(t, int64(0), val) // 2-6 should be available for i := int64(2); i <= 6; i++ { val, err := c.Get(i) - assert.Nil(t, err) - assert.Equal(t, i, val) + assert.Nil(t, err) + assert.Equal(t, i, val) } - assert.Equal(t, 5, len(c.items)) + assert.Equal(t, 5, len(c.items)) } diff --git a/datastructures/cache/multilevel.go b/datastructures/cache/multilevel.go index 0c2254e..15309e4 100644 --- a/datastructures/cache/multilevel.go +++ b/datastructures/cache/multilevel.go @@ -49,7 +49,7 @@ func (c *MultiLevelCacheImpl[K, V]) Get(ctx context.Context, key K) (V, error) { } } - var empty V + var empty V if item == empty || err != nil { return empty, &Miss{Where: "ALL_LEVELS", Key: key} }