Skip to content

Commit

Permalink
util/generic: add a generic version of sync.Map (#37474)
Browse files Browse the repository at this point in the history
ref #35983
  • Loading branch information
tangenta authored Aug 30, 2022
1 parent d5c96ce commit 5f35b32
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 0 deletions.
17 changes: 17 additions & 0 deletions util/generic/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "generic",
srcs = ["sync_map.go"],
importpath = "github.com/pingcap/tidb/util/generic",
visibility = ["//visibility:public"],
)

go_test(
name = "generic_test",
srcs = ["sync_map_test.go"],
deps = [
":generic",
"@com_github_stretchr_testify//require",
],
)
63 changes: 63 additions & 0 deletions util/generic/sync_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package generic

import "sync"

// SyncMap is the generic version of the sync.Map.
type SyncMap[K comparable, V any] struct {
item map[K]V
mu sync.RWMutex
}

// NewSyncMap returns a new SyncMap.
func NewSyncMap[K comparable, V any](capacity int) SyncMap[K, V] {
return SyncMap[K, V]{
item: make(map[K]V, capacity),
}
}

// Store stores a value.
func (m *SyncMap[K, V]) Store(key K, value V) {
m.mu.Lock()
m.item[key] = value
m.mu.Unlock()
}

// Load loads a key value.
func (m *SyncMap[K, V]) Load(key K) (V, bool) {
m.mu.RLock()
val, exist := m.item[key]
m.mu.RUnlock()
return val, exist
}

// Delete deletes a key value.
func (m *SyncMap[K, V]) Delete(key K) {
m.mu.Lock()
delete(m.item, key)
m.mu.Unlock()
}

// Keys returns all the keys in the map.
func (m *SyncMap[K, V]) Keys() []K {
ret := make([]K, 0, len(m.item))
m.mu.RLock()
for k := range m.item {
ret = append(ret, k)
}
m.mu.RUnlock()
return ret
}
62 changes: 62 additions & 0 deletions util/generic/sync_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package generic_test

import (
"sort"
"testing"

"github.com/pingcap/tidb/util/generic"
"github.com/stretchr/testify/require"
)

func TestSyncMap(t *testing.T) {
sm := generic.NewSyncMap[int64, string](10)
sm.Store(1, "a")
sm.Store(2, "b")
// Load an exist key.
v, ok := sm.Load(1)
require.True(t, ok)
require.Equal(t, "a", v)
// Load a non-exist key.
v, ok = sm.Load(3)
require.False(t, ok)
require.Equal(t, "", v)
// Overwrite the value.
sm.Store(1, "c")
v, ok = sm.Load(1)
require.True(t, ok)
require.Equal(t, "c", v)
// Drop an exist key.
sm.Delete(1)
v, ok = sm.Load(1)
require.False(t, ok)
require.Equal(t, "", v)
// Drop a non-exist key.
sm.Delete(3)
require.Equal(t, []int64{2}, sm.Keys())
v, ok = sm.Load(3)
require.False(t, ok)
require.Equal(t, "", v)

// Test the Keys() method.
sm = generic.NewSyncMap[int64, string](10)
sm.Store(2, "b")
sm.Store(1, "a")
sm.Store(3, "c")
keys := sm.Keys()
sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
require.Equal(t, []int64{1, 2, 3}, keys)
}

0 comments on commit 5f35b32

Please sign in to comment.