Skip to content

Commit

Permalink
all: client runtime index
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Mar 25, 2024
1 parent 2611534 commit 1428d60
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 37 deletions.
84 changes: 84 additions & 0 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package client
import (
"encoding"
"fmt"
"net/netip"
"sync"

"github.com/AdguardTeam/AdGuardHome/internal/whois"
)
Expand Down Expand Up @@ -157,3 +159,85 @@ func (r *Runtime) IsEmpty() (ok bool) {
r.dhcp == nil &&
r.hostsFile == nil
}

// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// indexMu protects index.
indexMu *sync.RWMutex

// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}

// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
indexMu: &sync.RWMutex{},
index: map[netip.Addr]*Runtime{},
}
}

// Client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime, ok bool) {
ri.indexMu.RLock()
defer ri.indexMu.RUnlock()

rc, ok = ri.index[ip]

return rc, ok
}

// Add saves the runtime client by ip.
func (ri *RuntimeIndex) Add(ip netip.Addr, rc *Runtime) {
ri.indexMu.Lock()
defer ri.indexMu.Unlock()

ri.index[ip] = rc
}

// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
ri.indexMu.RLock()
defer ri.indexMu.RUnlock()

return len(ri.index)
}

// Range calls cb for each runtime client.
func (ri *RuntimeIndex) Range(cb func(ip netip.Addr, rc *Runtime) (cont bool)) {
ri.indexMu.RLock()
defer ri.indexMu.RUnlock()

for ip, rc := range ri.index {
if !cb(ip, rc) {
return
}
}
}

// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
ri.indexMu.Lock()
defer ri.indexMu.Unlock()

delete(ri.index, ip)
}

// DeleteBySrc removes all runtime clients that have information only from the
// specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySrc(src Source) (n int) {
ri.indexMu.Lock()
defer ri.indexMu.Unlock()

for ip, rc := range ri.index {
rc.Unset(src)

if rc.IsEmpty() {
delete(ri.index, ip)
n++
}
}

return n
}
50 changes: 18 additions & 32 deletions internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ type clientsContainer struct {
// types (string, netip.Addr, and so on).
list map[string]*client.Persistent // name -> client

// clientIndex stores information about persistent clients.
clientIndex *client.Index

// ipToRC maps IP addresses to runtime client information.
ipToRC map[netip.Addr]*client.Runtime
// runtimeIndex stores information about runtime clients.
runtimeIndex *client.RuntimeIndex

allTags *stringutil.Set

Expand Down Expand Up @@ -104,7 +105,7 @@ func (clients *clientsContainer) Init(
}

clients.list = map[string]*client.Persistent{}
clients.ipToRC = map[netip.Addr]*client.Runtime{}
clients.runtimeIndex = client.NewRuntimeIndex()

clients.clientIndex = client.NewIndex()

Expand Down Expand Up @@ -362,7 +363,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return client.SourcePersistent
}

rc, ok := clients.ipToRC[ip]
rc, ok := clients.runtimeIndex.Client(ip)
if ok {
src, _ = rc.Info()
}
Expand Down Expand Up @@ -558,10 +559,7 @@ func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtim
return nil, false
}

clients.lock.Lock()
defer clients.lock.Unlock()

rc, ok = clients.ipToRC[ip]
rc, ok = clients.runtimeIndex.Client(ip)

return rc, ok
}
Expand Down Expand Up @@ -733,12 +731,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return
}

rc, ok := clients.ipToRC[ip]
rc, ok := clients.runtimeIndex.Client(ip)
if !ok {
// Create a RuntimeClient implicitly so that we don't do this check
// again.
rc = &client.Runtime{}
clients.ipToRC[ip] = rc
clients.runtimeIndex.Add(ip, rc)

log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
Expand Down Expand Up @@ -797,7 +795,7 @@ func (clients *clientsContainer) addHostLocked(
host string,
src client.Source,
) (ok bool) {
rc, ok := clients.ipToRC[ip]
rc, ok := clients.runtimeIndex.Client(ip)
if !ok {
if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" {
Expand All @@ -806,52 +804,39 @@ func (clients *clientsContainer) addHostLocked(
}

rc = &client.Runtime{}
clients.ipToRC[ip] = rc
clients.runtimeIndex.Add(ip, rc)
}

rc.SetInfo(src, []string{host})

log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC))
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, clients.runtimeIndex.Size())

return true
}

// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
n := 0
for ip, rc := range clients.ipToRC {
rc.Unset(src)
if rc.IsEmpty() {
delete(clients.ipToRC, ip)
n++
}
}

log.Debug("clients: removed %d client aliases", n)
}

// addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files.
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
clients.lock.Lock()
defer clients.lock.Unlock()

clients.rmHostsBySrc(client.SourceHostsFile)
deleted := clients.runtimeIndex.DeleteBySrc(client.SourceHostsFile)
log.Debug("clients: removed %d client aliases from system hosts file", deleted)

n := 0
added := 0
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical
// hostname for the IP address.
//
// TODO(e.burkov): Consider using all the names from all the records.
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
n++
added++
}

return true
})

log.Debug("clients: added %d client aliases from system hosts file", n)
log.Debug("clients: added %d client aliases from system hosts file", added)
}

// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
Expand All @@ -875,7 +860,8 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock()
defer clients.lock.Unlock()

clients.rmHostsBySrc(client.SourceARP)
deleted := clients.runtimeIndex.DeleteBySrc(client.SourceARP)
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)

added := 0
for _, n := range ns {
Expand Down
9 changes: 6 additions & 3 deletions internal/home/clients_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,9 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
rc, ok := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc)
require.True(t, ok)

assert.Equal(t, whois, rc.WHOIS())
})
Expand All @@ -256,8 +257,9 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok)

clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
rc, ok := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc)
require.True(t, ok)

assert.Equal(t, whois, rc.WHOIS())
})
Expand All @@ -274,8 +276,9 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok)

clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip]
rc, ok := clients.runtimeIndex.Client(ip)
require.Nil(t, rc)
require.False(t, ok)

assert.True(t, clients.remove("client1"))
})
Expand Down
6 changes: 4 additions & 2 deletions internal/home/clientshttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
data.Clients = append(data.Clients, cj)
}

for ip, rc := range clients.ipToRC {
clients.runtimeIndex.Range(func(ip netip.Addr, rc *client.Runtime) (cont bool) {
src, host := rc.Info()
cj := runtimeClientJSON{
WHOIS: whoisOrEmpty(rc),
Expand All @@ -111,7 +111,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
}

data.RuntimeClients = append(data.RuntimeClients, cj)
}

return true
})

for _, l := range clients.dhcp.Leases() {
cj := runtimeClientJSON{
Expand Down

0 comments on commit 1428d60

Please sign in to comment.