Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DNS record TTL and whether a record is in-use to determine expiration #7941

Merged
merged 7 commits into from
Jun 28, 2021
97 changes: 51 additions & 46 deletions pkg/network/dns_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package network

import (
"sort"
"strings"
"sync"
"sync/atomic"
"time"
Expand All @@ -27,19 +26,17 @@ type reverseDNSCache struct {
mux sync.Mutex
data map[util.Address]*dnsCacheVal
exit chan struct{}
ttl time.Duration
size int

// maxDomainsPerIP is the maximum number of domains mapped to a single IP
maxDomainsPerIP int
oversizedLogLimit *util.LogLimit
}

func newReverseDNSCache(size int, ttl, expirationPeriod time.Duration) *reverseDNSCache {
func newReverseDNSCache(size int, expirationPeriod time.Duration) *reverseDNSCache {
cache := &reverseDNSCache{
data: make(map[util.Address]*dnsCacheVal),
exit: make(chan struct{}),
ttl: ttl,
size: size,
oversizedLogLimit: util.NewLogLimit(10, time.Minute*10),
maxDomainsPerIP: 1000,
Expand All @@ -60,7 +57,7 @@ func newReverseDNSCache(size int, ttl, expirationPeriod time.Duration) *reverseD
return cache
}

func (c *reverseDNSCache) Add(translation *translation, now time.Time) bool {
func (c *reverseDNSCache) Add(translation *translation) bool {
if translation == nil {
return false
}
Expand All @@ -71,17 +68,16 @@ func (c *reverseDNSCache) Add(translation *translation, now time.Time) bool {
return false
}

exp := now.Add(c.ttl).UnixNano()
for _, addr := range translation.ips {
for addr, deadline := range translation.ips {
val, ok := c.data[addr]
if ok {
val.expiration = exp
if rejected := val.merge(translation.dns, c.maxDomainsPerIP); rejected && c.oversizedLogLimit.ShouldLog() {
if rejected := val.merge(translation.dns, deadline, c.maxDomainsPerIP); rejected && c.oversizedLogLimit.ShouldLog() {
log.Warnf("%s mapped to too many domains, DNS information will be dropped (this will be logged the first 10 times, and then at most every 10 minutes)", addr)
}
} else {
atomic.AddInt64(&c.added, 1)
c.data[addr] = &dnsCacheVal{names: []string{translation.dns}, expiration: exp}
// flag as in use, so mapping survives until next time connections are queried, in case TTL is shorter
c.data[addr] = &dnsCacheVal{names: map[string]time.Time{translation.dns: deadline}, inUse: true}
ISauve marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand All @@ -91,7 +87,14 @@ func (c *reverseDNSCache) Add(translation *translation, now time.Time) bool {
return true
}

func (c *reverseDNSCache) Get(conns []ConnectionStats, now time.Time) map[util.Address][]string {
func (c *reverseDNSCache) Get(conns []ConnectionStats) map[util.Address][]string {
c.mux.Lock()
defer c.mux.Unlock()

for _, val := range c.data {
val.inUse = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you please add a note perhaps here or at the type definition explaining the lifecycle of inUse. I don't quite follow how it works.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note added.

}

if len(conns) == 0 {
return nil
}
Expand All @@ -100,7 +103,6 @@ func (c *reverseDNSCache) Get(conns []ConnectionStats, now time.Time) map[util.A
resolved = make(map[util.Address][]string)
unresolved = make(map[util.Address]struct{})
oversized = make(map[util.Address]struct{})
expiration = now.Add(c.ttl).UnixNano()
)

collectNamesForIP := func(addr util.Address) {
Expand All @@ -116,7 +118,7 @@ func (c *reverseDNSCache) Get(conns []ConnectionStats, now time.Time) map[util.A
return
}

names := c.getNamesForIP(addr, expiration)
names := c.getNamesForIP(addr)
if len(names) == 0 {
unresolved[addr] = struct{}{}
} else if len(names) == c.maxDomainsPerIP {
Expand All @@ -126,12 +128,10 @@ func (c *reverseDNSCache) Get(conns []ConnectionStats, now time.Time) map[util.A
}
}

c.mux.Lock()
for _, conn := range conns {
collectNamesForIP(conn.Source)
collectNamesForIP(conn.Dest)
}
c.mux.Unlock()

// Update stats for telemetry
atomic.AddInt64(&c.lookups, int64(len(resolved)+len(unresolved)))
Expand Down Expand Up @@ -170,14 +170,22 @@ func (c *reverseDNSCache) Close() {
}

func (c *reverseDNSCache) Expire(now time.Time) {
deadline := now.UnixNano()
expired := 0
c.mux.Lock()
for addr, val := range c.data {
if val.expiration > deadline {
if val.inUse {
continue
}

for ip, deadline := range val.names {
if deadline.Before(now) {
delete(val.names, ip)
}
}

if len(val.names) != 0 {
continue
}
expired++
delete(c.data, addr)
}
Expand All @@ -192,61 +200,58 @@ func (c *reverseDNSCache) Expire(now time.Time) {
)
}

func (c *reverseDNSCache) getNamesForIP(ip util.Address, updatedTTL int64) []string {
func (c *reverseDNSCache) getNamesForIP(ip util.Address) []string {
val, ok := c.data[ip]
if !ok {
return nil
}

val.expiration = updatedTTL
val.inUse = true
return val.copy()
}

type dnsCacheVal struct {
// opting for a []string instead of map[string]struct{} since common case is len(names) == 1
names []string
expiration int64
names map[string]time.Time
// inUse keeps track of whether this dns cache record is currently in use by a connection.
// This flag is reset to false every time reverseDnsCache.Get is called.
// This flag is only set to true if reverseDNSCache.getNamesForIP returns this struct.
// If inUse is set, then this record will not be expired out.
inUse bool
}

func (v *dnsCacheVal) merge(name string, maxSize int) (rejected bool) {
normalized := strings.ToLower(name)
if i := sort.SearchStrings(v.names, normalized); i < len(v.names) && v.names[i] == normalized {
func (v *dnsCacheVal) merge(name string, deadline time.Time, maxSize int) (rejected bool) {
if exp, ok := v.names[name]; ok {
if deadline.After(exp) {
v.names[name] = deadline
brycekahle marked this conversation as resolved.
Show resolved Hide resolved
v.inUse = true
}
return false
}

if len(v.names) == maxSize {
return true
}

v.names = append(v.names, normalized)
sort.Strings(v.names)
v.names[name] = deadline
brycekahle marked this conversation as resolved.
Show resolved Hide resolved
v.inUse = true
return false
}

func (v *dnsCacheVal) copy() []string {
cpy := make([]string, len(v.names))
copy(cpy, v.names)
cpy := make([]string, 0, len(v.names))
for n := range v.names {
cpy = append(cpy, n)
}
sort.Strings(cpy)
return cpy
}

type translation struct {
dns string
ips []util.Address
ips map[util.Address]time.Time
}

func newTranslation(domain []byte) *translation {
return &translation{
dns: string(domain),
ips: nil,
func (t *translation) add(addr util.Address, ttl time.Duration) {
if _, ok := t.ips[addr]; ok {
return
}
}

func (t *translation) add(addr util.Address) {
for _, other := range t.ips {
if other == addr {
return
}
}

t.ips = append(t.ips, addr)
t.ips[addr] = time.Now().Add(ttl)
}
Loading