Skip to content

Commit

Permalink
all: safesearch rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizzick committed Apr 12, 2024
1 parent fb3efbb commit b78ad8f
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 236 deletions.
7 changes: 4 additions & 3 deletions internal/client/persistent.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ type Persistent struct {
// upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig

// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch
SafeSearch filtering.SafeSearch

// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
Expand Down Expand Up @@ -95,6 +93,9 @@ type Persistent struct {
UseOwnBlockedServices bool
IgnoreQueryLog bool
IgnoreStatistics bool

// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
}

// SetTags sets the tags if they are known, otherwise logs an unknown tag.
Expand Down
34 changes: 13 additions & 21 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dnsforward

import (
"cmp"
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
Expand Down Expand Up @@ -491,19 +490,10 @@ func TestServerRace(t *testing.T) {
}

func TestSafeSearch(t *testing.T) {
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)

return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}

safeSearchConf := filtering.SafeSearchConfig{
Enabled: true,
Google: true,
Yandex: true,
CustomResolver: resolver,
Enabled: true,
Google: true,
Yandex: true,
}

filterConf := &filtering.Config{
Expand Down Expand Up @@ -540,7 +530,6 @@ func TestSafeSearch(t *testing.T) {
client := &dns.Client{}

yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")

testCases := []struct {
host string
Expand All @@ -564,19 +553,19 @@ func TestSafeSearch(t *testing.T) {
wantCNAME: "",
}, {
host: "www.google.com.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.com.af.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.be.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}, {
host: "www.google.by.",
want: googleIP,
want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.",
}}

Expand All @@ -593,12 +582,15 @@ func TestSafeSearch(t *testing.T) {

cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
assert.Equal(t, tc.wantCNAME, cname.Target)

a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1])
assert.NotEmpty(t, a.A)
} else {
require.Len(t, reply.Answer, 1)
}

a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
}
})
}
}
Expand Down
17 changes: 8 additions & 9 deletions internal/dnsforward/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
req := pctx.Req
q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".")

resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts)
if err != nil {
return nil, fmt.Errorf("checking host %q: %w", host, err)
Expand All @@ -39,22 +40,20 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
// TODO(a.garipov): Make CheckHost return a pointer.
res = &resVal
switch {
case res.IsFiltered:
log.Debug(
"dnsforward: host %q is filtered, reason: %q; rule: %q",
host,
res.Reason,
res.Rules[0].Text,
)
case res.IsFiltered && res.CanonName == "":
log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
case res.Reason.In(
filtering.Rewritten,
filtering.RewrittenRule,
filtering.FilteredSafeSearch) &&
res.CanonName != "" &&
len(res.IPList) == 0:
// Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse.
dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.Reason == filtering.Rewritten:
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, pctx); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion internal/dnsforward/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess
case
filtering.Rewritten,
filtering.RewrittenRule:
filtering.RewrittenRule,
filtering.FilteredSafeSearch:

if dctx.origQuestion.Name == "" {
// origQuestion is set in case we get only CNAME without IP from
Expand Down
3 changes: 0 additions & 3 deletions internal/filtering/safesearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ type SafeSearch interface {

// SafeSearchConfig is a struct with safe search related settings.
type SafeSearchConfig struct {
// CustomResolver is the resolver used by safe search.
CustomResolver Resolver `yaml:"-" json:"-"`

// Enabled indicates if safe search is enabled entirely.
Enabled bool `yaml:"enabled" json:"enabled"`

Expand Down
79 changes: 5 additions & 74 deletions internal/filtering/safesearch/safesearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ package safesearch

import (
"bytes"
"context"
"encoding/binary"
"encoding/gob"
"fmt"
"net"
"net/netip"
"strings"
"sync"
Expand Down Expand Up @@ -67,7 +65,6 @@ type Default struct {
engine *urlfilter.DNSEngine

cache cache.Cache
resolver filtering.Resolver
logPrefix string
cacheTTL time.Duration
}
Expand All @@ -80,19 +77,13 @@ func NewDefault(
cacheSize uint,
cacheTTL time.Duration,
) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}

ss = &Default{
mu: &sync.RWMutex{},

cache: cache.New(cache.Config{
EnableLRU: true,
MaxSize: cacheSize,
}),
resolver: resolver,
// Use %s, because the client safe-search names already contain double
// quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
Expand Down Expand Up @@ -228,18 +219,11 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
// newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is
// never nil, so that the empty result is converted into a NODATA response.
//
// TODO(a.garipov): Use the main rewrite result mechanism used in
// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save
// them in the safe search cache.
func (ss *Default) newResult(
rewrite *rules.DNSRewrite,
qtype rules.RRType,
) (res *filtering.Result, err error) {
res = &filtering.Result{
Rules: []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
}},
Reason: filtering.FilteredSafeSearch,
IsFiltered: true,
}
Expand All @@ -250,72 +234,19 @@ func (ss *Default) newResult(
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value)
}

res.Rules[0].IP = ip

return res, nil
}

host := rewrite.NewCNAME
if host == "" {
return res, nil
}
res.Rules = []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
IP: ip,
}}

res.CanonName = host
if qtype == dns.TypeHTTPS {
return res, nil
}

ss.log(log.DEBUG, "resolving %q", host)

ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, fmt.Errorf("resolving cname: %w", err)
}

ss.log(log.DEBUG, "resolved %s", ips)

for _, ip := range ips {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
addr := fitToProto(ip, qtype)
if addr == (netip.Addr{}) {
continue
}

// TODO(e.burkov): Rules[0]?
res.Rules[0].IP = addr
}
res.CanonName = rewrite.NewCNAME

return res, nil
}

// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}

// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) {
if ip4 := ip.To4(); qtype == dns.TypeA {
if ip4 != nil {
return netip.AddrFrom4([4]byte(ip4))
}
} else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA {
return netip.AddrFrom16([16]byte(ip))
}

return netip.Addr{}
}

// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
Expand Down
44 changes: 0 additions & 44 deletions internal/filtering/safesearch/safesearch_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package safesearch

import (
"context"
"net"
"net/netip"
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
Expand Down Expand Up @@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}

func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"

ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})

res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)

assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)

resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)

return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}

ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver

// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)

wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)

res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)

assert.Equal(t, wantIP, res.Rules[0].IP)

// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)

assert.Equal(t, wantIP, cachedValue.Rules[0].IP)
}

const googleHost = "www.google.com"

var dnsRewriteSink *rules.DNSRewrite
Expand Down
Loading

0 comments on commit b78ad8f

Please sign in to comment.