Skip to content

Commit

Permalink
all: imp code, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Feb 6, 2024
1 parent 4a44e99 commit 7abe33d
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 86 deletions.
44 changes: 29 additions & 15 deletions internal/aghalg/orderedmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,27 @@ import (
"golang.org/x/exp/slices"
)

// OrderedMap is the implementation of the ordered map data structure.
type OrderedMap[K comparable, V any] struct {
// SortedMap is a map that keeps elements in order with internal sorting
// function.
type SortedMap[K comparable, V any] struct {
vals map[K]V
cmp func(a, b K) int
cmp func(a, b K) (res int)
keys []K
}

// NewOrderedMap initializes the new instance of ordered map. cmp is a sort
// function.
// NewSortedMap initializes the new instance of sorted map. cmp is a sort
// function to keep elements in order.
//
// TODO(s.chzhen): Use cmp.Compare in Go 1.21
func NewOrderedMap[K comparable, V any](cmp func(a, b K) int) OrderedMap[K, V] {
return OrderedMap[K, V]{
// TODO(s.chzhen): Use cmp.Compare in Go 1.21.
func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] {
return SortedMap[K, V]{
vals: make(map[K]V),
cmp: cmp,
}
}

// Set adds val with key to the ordered map.
func (m *OrderedMap[K, V]) Set(key K, val V) {
// Set adds val with key to the sorted map.
func (m *SortedMap[K, V]) Set(key K, val V) {
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
if has {
m.keys[i] = key
Expand All @@ -36,17 +37,30 @@ func (m *OrderedMap[K, V]) Set(key K, val V) {
m.vals[key] = val
}

// Del removes the value by key from the ordered map.
func (m *OrderedMap[K, V]) Del(key K) {
// Get returns val by key from the sorted map.
func (m *SortedMap[K, V]) Get(key K) (val V) {
return m.vals[key]
}

// Del removes the value by key from the sorted map.
func (m *SortedMap[K, V]) Del(key K) {
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
if has {
m.keys = slices.Delete(m.keys, i, 1)
m.keys = slices.Delete(m.keys, i, i+1)
delete(m.vals, key)
}
}

// Range calls cb for each element of the map. If cb returns false it stops.
func (m *OrderedMap[K, V]) Range(cb func(K, V) (cont bool)) {
// Clear removes all elements from the sorted map.
func (m *SortedMap[K, V]) Clear() {
// TODO(s.chzhen): Use built-in clear in Go 1.21.
m.keys = nil
m.vals = make(map[K]V)
}

// Range calls cb for each element of the map, sorted by m.cmp. If cb returns
// false it stops.
func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) {
for _, k := range m.keys {
if !cb(k, m.vals[k]) {
return
Expand Down
15 changes: 9 additions & 6 deletions internal/aghalg/orderedmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"github.com/stretchr/testify/assert"
)

func TestNewOrderedMap(t *testing.T) {
var m OrderedMap[string, int]
func TestNewSortedMap(t *testing.T) {
var m SortedMap[string, int]

letters := []string{}
for i := 0; i < 10; i++ {
Expand All @@ -17,7 +17,7 @@ func TestNewOrderedMap(t *testing.T) {
}

t.Run("create_and_fill", func(t *testing.T) {
m = NewOrderedMap[string, int](strings.Compare)
m = NewSortedMap[string, int](strings.Compare)

nums := []int{}
for i, r := range letters {
Expand All @@ -36,12 +36,15 @@ func TestNewOrderedMap(t *testing.T) {

assert.Equal(t, letters, gotLetters)
assert.Equal(t, nums, gotNums)
assert.Equal(t, nums[0], m.Get(letters[0]))
})

t.Run("clear", func(t *testing.T) {
for _, r := range letters {
m.Del(r)
}
lastLetter := letters[len(letters)-1]
m.Del(lastLetter)
assert.Equal(t, 0, m.Get(lastLetter))

m.Clear()

gotLetters := []string{}
m.Range(func(k string, _ int) bool {
Expand Down
77 changes: 43 additions & 34 deletions internal/home/clientindex.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,29 @@ import (
"net/netip"

"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"golang.org/x/exp/slices"
)

// macUID contains MAC and UID.
type macUID struct {
mac net.HardwareAddr
uid UID
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
type macKey any

func macToKey(mac net.HardwareAddr) (key macKey) {
switch len(mac) {
case 6:
arr := [6]byte{}
copy(arr[:], mac[:])

return arr
case 8:
arr := [8]byte{}
copy(arr[:], mac[:])

return arr
default:
arr := [20]byte{}
copy(arr[:], mac[:])

return arr
}
}

// clientIndex stores all information about persistent clients.
Expand All @@ -20,9 +36,9 @@ type clientIndex struct {

ipToUID map[netip.Addr]UID

subnetToUID aghalg.OrderedMap[netip.Prefix, UID]
subnetToUID aghalg.SortedMap[netip.Prefix, UID]

macUIDs []*macUID
macToUID map[macKey]UID

uidToClient map[UID]*persistentClient
}
Expand All @@ -32,7 +48,8 @@ func NewClientIndex() (ci *clientIndex) {
return &clientIndex{
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewOrderedMap[netip.Prefix, UID](subnetCompare),
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{},
uidToClient: map[UID]*persistentClient{},
}
}
Expand All @@ -52,14 +69,15 @@ func (ci *clientIndex) add(c *persistentClient) {
}

for _, mac := range c.MACs {
ci.macUIDs = append(ci.macUIDs, &macUID{mac, c.UID})
k := macToKey(mac)
ci.macToUID[k] = c.UID
}

ci.uidToClient[c.UID] = c
}

// contains returns true if the index already has information about persistent
// client.
// contains returns true if the index contains a persistent client with at least
// a single identifier contained by c.
func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
for _, id := range c.ClientIDs {
_, ok = ci.clientIDToUID[id]
Expand All @@ -76,7 +94,7 @@ func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
}

for _, pref := range c.Subnets {
ci.subnetToUID.Range(func(p netip.Prefix, id UID) bool {
ci.subnetToUID.Range(func(p netip.Prefix, _ UID) (cont bool) {
if pref == p {
ok = true

Expand All @@ -92,10 +110,8 @@ func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
}

for _, mac := range c.MACs {
ok = slices.ContainsFunc(ci.macUIDs, func(muid *macUID) bool {
return slices.Compare(mac, muid.mac) == 0
})

k := macToKey(mac)
_, ok = ci.macToUID[k]
if ok {
return true
}
Expand All @@ -104,7 +120,7 @@ func (ci *clientIndex) contains(c *persistentClient) (ok bool) {
return false
}

// find finds persistent client by string represenation of the client ID, IP
// find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {
uid, found := ci.clientIDToUID[id]
Expand All @@ -114,7 +130,11 @@ func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {

ip, err := netip.ParseAddr(id)
if err == nil {
return ci.findByIP(ip)
// MAC addresses can be successfully parsed as IP addresses.
c, found = ci.findByIP(ip)
if found {
return c, true
}
}

mac, err := net.ParseMAC(id)
Expand All @@ -132,7 +152,7 @@ func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool)
return ci.uidToClient[uid], true
}

ci.subnetToUID.Range(func(pref netip.Prefix, id UID) bool {
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) {
uid, found = id, true

Expand All @@ -151,17 +171,8 @@ func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool)

// find finds persistent client by MAC.
func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) {
var uid UID
found = slices.ContainsFunc(ci.macUIDs, func(muid *macUID) bool {
if slices.Compare(mac, muid.mac) == 0 {
uid = muid.uid

return true
}

return false
})

k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
return ci.uidToClient[uid], true
}
Expand All @@ -184,10 +195,8 @@ func (ci *clientIndex) del(c *persistentClient) {
}

for _, mac := range c.MACs {
ci.macUIDs = append(ci.macUIDs, &macUID{mac, c.UID})
slices.DeleteFunc(ci.macUIDs, func(muid *macUID) bool {
return slices.Compare(mac, muid.mac) == 0
})
k := macToKey(mac)
delete(ci.macToUID, k)
}

delete(ci.uidToClient, c.UID)
Expand Down
Loading

0 comments on commit 7abe33d

Please sign in to comment.