Skip to content

Commit

Permalink
all: client persistent index
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Feb 21, 2024
1 parent 4605e7c commit a4fc949
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 169 deletions.
7 changes: 0 additions & 7 deletions internal/aghtest/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
Expand Down Expand Up @@ -94,9 +93,6 @@ type AddressProcessor struct {
OnClose func() (err error)
}

// type check
var _ client.AddressProcessor = (*AddressProcessor)(nil)

// Process implements the [client.AddressProcessor] interface for
// *AddressProcessor.
func (p *AddressProcessor) Process(ip netip.Addr) {
Expand All @@ -114,9 +110,6 @@ type AddressUpdater struct {
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
}

// type check
var _ client.AddressUpdater = (*AddressUpdater)(nil)

// UpdateAddress implements the [client.AddressUpdater] interface for
// *AddressUpdater.
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
Expand Down
7 changes: 7 additions & 0 deletions internal/aghtest/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aghtest_test

import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
Expand All @@ -13,3 +14,9 @@ var _ filtering.Resolver = (*aghtest.Resolver)(nil)

// type check
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)

// type check
var _ client.AddressProcessor = (*aghtest.AddressProcessor)(nil)

// type check
var _ client.AddressUpdater = (*aghtest.AddressUpdater)(nil)
42 changes: 21 additions & 21 deletions internal/home/clientindex.go → internal/client/index.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package home
package client

import (
"fmt"
Expand Down Expand Up @@ -26,8 +26,8 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
}
}

// clientIndex stores all information about persistent clients.
type clientIndex struct {
// Index stores all information about persistent clients.
type Index struct {
// clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID

Expand All @@ -38,26 +38,26 @@ type clientIndex struct {
macToUID map[macKey]UID

// uidToClient maps UID to the persistent client.
uidToClient map[UID]*persistentClient
uidToClient map[UID]*Persistent

// subnetToUID maps subnet to UID.
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}

// NewClientIndex initializes the new instance of client index.
func NewClientIndex() (ci *clientIndex) {
return &clientIndex{
// NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) {
return &Index{
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{},
uidToClient: map[UID]*persistentClient{},
uidToClient: map[UID]*Persistent{},
}
}

// add stores information about a persistent client in the index. c must be
// Add stores information about a persistent client in the index. c must be
// non-nil and contain UID.
func (ci *clientIndex) add(c *persistentClient) {
func (ci *Index) Add(c *Persistent) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
Expand All @@ -82,9 +82,9 @@ func (ci *clientIndex) add(c *persistentClient) {
ci.uidToClient[c.UID] = c
}

// clashes returns an error if the index contains a different persistent client
// Clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *clientIndex) clashes(c *persistentClient) (err error) {
func (ci *Index) Clashes(c *Persistent) (err error) {
for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
Expand Down Expand Up @@ -114,7 +114,7 @@ func (ci *clientIndex) clashes(c *persistentClient) (err error) {

// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) {
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
Expand All @@ -127,7 +127,7 @@ func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip n

// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) {
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
Expand All @@ -153,7 +153,7 @@ func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient,

// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) {
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
Expand All @@ -165,9 +165,9 @@ func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac
return nil, nil
}

// find finds persistent client by string representation of the client ID, IP
// Find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
Expand All @@ -191,7 +191,7 @@ func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {
}

// find finds persistent client by IP address.
func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) {
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
Expand All @@ -215,7 +215,7 @@ func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool)
}

// find finds persistent client by MAC.
func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) {
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
Expand All @@ -225,9 +225,9 @@ func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, fou
return nil, false
}

// del removes information about persistent client from the index. c must be
// Del removes information about persistent client from the index. c must be
// non-nil.
func (ci *clientIndex) del(c *persistentClient) {
func (ci *Index) Del(c *Persistent) {
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package home
package client

import (
"net"
Expand All @@ -9,6 +9,19 @@ import (
"github.com/stretchr/testify/require"
)

// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*Persistent) (ci *Index) {
ci = NewIndex()

for _, c := range m {
c.UID = MustNewUID()
ci.Add(c)
}

return ci
}

func TestClientIndex(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
Expand All @@ -24,7 +37,7 @@ func TestClientIndex(t *testing.T) {
cliMAC = "11:11:11:11:11:11"
)

clients := []*persistentClient{{
clients := []*Persistent{{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
Expand All @@ -45,9 +58,9 @@ func TestClientIndex(t *testing.T) {
ci := newIDIndex(clients)

testCases := []struct {
want *Persistent
name string
ids []string
want *persistentClient
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
Expand All @@ -69,7 +82,7 @@ func TestClientIndex(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.find(id)
c, ok := ci.Find(id)
require.True(t, ok)

assert.Equal(t, tc.want, c)
Expand All @@ -78,7 +91,7 @@ func TestClientIndex(t *testing.T) {
}

t.Run("not_found", func(t *testing.T) {
_, ok := ci.find(cliIPNone)
_, ok := ci.Find(cliIPNone)
assert.False(t, ok)
})
}
Expand All @@ -92,7 +105,7 @@ func TestClientIndex_Clashes(t *testing.T) {
cliMAC = "11:11:11:11:11:11"
)

clients := []*persistentClient{{
clients := []*Persistent{{
Name: "client_with_ip",
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}, {
Expand All @@ -109,8 +122,8 @@ func TestClientIndex_Clashes(t *testing.T) {
ci := newIDIndex(clients)

testCases := []struct {
client *Persistent
name string
client *persistentClient
}{{
name: "ipv4",
client: clients[0],
Expand All @@ -127,14 +140,14 @@ func TestClientIndex_Clashes(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
clone := tc.client.shallowClone()
clone := tc.client.ShallowClone()
clone.UID = MustNewUID()

err := ci.clashes(clone)
err := ci.Clashes(clone)
require.Error(t, err)

ci.del(tc.client)
err = ci.clashes(clone)
ci.Del(tc.client)
err = ci.Clashes(clone)
require.NoError(t, err)
})
}
Expand All @@ -153,9 +166,9 @@ func mustParseMAC(s string) (mac net.HardwareAddr) {

func TestMACToKey(t *testing.T) {
testCases := []struct {
want any
name string
in string
want any
}{{
name: "column6",
in: "00:00:5e:00:53:01",
Expand Down
Loading

0 comments on commit a4fc949

Please sign in to comment.