Skip to content

Commit

Permalink
[app/dns] Support per-client configuration for fakedns (#2212)
Browse files Browse the repository at this point in the history
* Move `filterIP` from `hosts.go` to `dnscommon.go`

* Implement adding pools for fakedns.HolderMulti

* Implement per-client fakedns for DNS app

* Remove `dns.ClientWithIPOption` and replace with new programming model

* Implement JSON config support for new fakedns config

* Fix lint and tests

* Fix some codacy analysis
  • Loading branch information
Vigilans authored Dec 15, 2022
1 parent 32475d9 commit f8ac919
Show file tree
Hide file tree
Showing 24 changed files with 1,333 additions and 578 deletions.
562 changes: 312 additions & 250 deletions app/dns/config.pb.go

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions app/dns/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ option java_multiple_files = true;
import "common/net/address.proto";
import "common/net/destination.proto";
import "app/router/routercommon/common.proto";
import "app/dns/fakedns/fakedns.proto";

import "common/protoext/extensions.proto";

Expand All @@ -31,6 +32,8 @@ message NameServer {
repeated v2ray.core.app.router.routercommon.GeoIP geoip = 3;
repeated OriginalRule original_rules = 4;

v2ray.core.app.dns.fakedns.FakeDnsPoolMulti fake_dns = 11;

// Deprecated. Use fallback_strategy.
bool skipFallback = 6 [deprecated = true];

Expand Down Expand Up @@ -91,8 +94,12 @@ message Config {
// (IPv6).
bytes client_ip = 3;

// Static domain-ip mapping in DNS server.
repeated HostMapping static_hosts = 4;

// Global fakedns object.
v2ray.core.app.dns.fakedns.FakeDnsPoolMulti fake_dns = 16;

// Tag is the inbound tag of DNS client.
string tag = 6;

Expand Down Expand Up @@ -143,8 +150,12 @@ message SimplifiedConfig {
// (IPv6).
string client_ip = 3;

// Static domain-ip mapping in DNS server.
repeated HostMapping static_hosts = 4;

// Global fakedns object.
v2ray.core.app.dns.fakedns.FakeDnsPoolMulti fake_dns = 16;

// Tag is the inbound tag of DNS client.
string tag = 6;

Expand Down Expand Up @@ -204,6 +215,8 @@ message SimplifiedNameServer {
repeated v2ray.core.app.router.routercommon.GeoIP geoip = 3;
repeated OriginalRule original_rules = 4;

v2ray.core.app.dns.fakedns.FakeDnsPoolMulti fake_dns = 11;

// Deprecated. Use fallback_strategy.
bool skipFallback = 6 [deprecated = true];

Expand Down
172 changes: 115 additions & 57 deletions app/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strings"
"sync"

core "github.com/v2fly/v2ray-core/v5"
"github.com/v2fly/v2ray-core/v5/app/dns/fakedns"
"github.com/v2fly/v2ray-core/v5/app/router"
"github.com/v2fly/v2ray-core/v5/common"
"github.com/v2fly/v2ray-core/v5/common/errors"
Expand All @@ -28,11 +30,11 @@ import (
// DNS is a DNS rely server.
type DNS struct {
sync.Mutex
ipOption dns.IPOption
hosts *StaticHosts
clients []*Client
ctx context.Context
clientTags map[string]bool
fakeDNSEngine *FakeDNSEngine
domainMatcher strmatcher.IndexMatcher
matcherInfos []DomainMatcherInfo
}
Expand Down Expand Up @@ -78,31 +80,31 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
clients = append(clients, NewLocalDNSClient())
}

s := &DNS{
hosts: hosts,
clients: clients,
ctx: ctx,
}

// Establish members related to global DNS state
domainMatcher, matcherInfos, err := establishDomainRules(config, clients, nsClientMap)
if err != nil {
s.clientTags = make(map[string]bool)
for _, client := range clients {
s.clientTags[client.tag] = true
}
if err := establishDomainRules(s, config, nsClientMap); err != nil {
return nil, err
}
if err := establishExpectedIPs(config, clients, nsClientMap); err != nil {
if err := establishExpectedIPs(s, config, nsClientMap); err != nil {
return nil, err
}
clientTags := make(map[string]bool)
for _, client := range clients {
clientTags[client.tag] = true
if err := establishFakeDNS(s, config, nsClientMap); err != nil {
return nil, err
}

return &DNS{
ipOption: toIPOption(config.QueryStrategy),
hosts: hosts,
clients: clients,
ctx: ctx,
clientTags: clientTags,
domainMatcher: domainMatcher,
matcherInfos: matcherInfos,
}, nil
return s, nil
}

func establishDomainRules(config *Config, clients []*Client, nsClientMap map[int]int) (strmatcher.IndexMatcher, []DomainMatcherInfo, error) {
func establishDomainRules(s *DNS, config *Config, nsClientMap map[int]int) error {
domainRuleCount := 0
for _, ns := range config.NameServer {
domainRuleCount += len(ns.PrioritizedDomain)
Expand All @@ -128,7 +130,7 @@ func establishDomainRules(config *Config, clients []*Client, nsClientMap map[int
for _, domain := range ns.PrioritizedDomain {
domainRule, err := toStrMatcher(domain.Type, domain.Domain)
if err != nil {
return nil, nil, newError("failed to create prioritized domain").Base(err).AtWarning()
return newError("failed to create prioritized domain").Base(err).AtWarning()
}
originalRuleIdx := ruleCurr
if ruleCurr < len(ns.OriginalRules) {
Expand All @@ -151,18 +153,20 @@ func establishDomainRules(config *Config, clients []*Client, nsClientMap map[int
domainRuleIdx: uint16(originalRuleIdx),
}
if err != nil {
return nil, nil, newError("failed to create prioritized domain").Base(err).AtWarning()
return newError("failed to create prioritized domain").Base(err).AtWarning()
}
}
clients[clientIdx].domains = rules
s.clients[clientIdx].domains = rules
}
if err := domainMatcher.Build(); err != nil {
return nil, nil, err
return err
}
return domainMatcher, matcherInfos, nil
s.domainMatcher = domainMatcher
s.matcherInfos = matcherInfos
return nil
}

func establishExpectedIPs(config *Config, clients []*Client, nsClientMap map[int]int) error {
func establishExpectedIPs(s *DNS, config *Config, nsClientMap map[int]int) error {
geoipContainer := router.GeoIPMatcherContainer{}
for nsIdx, ns := range config.NameServer {
clientIdx := nsClientMap[nsIdx]
Expand All @@ -174,11 +178,51 @@ func establishExpectedIPs(config *Config, clients []*Client, nsClientMap map[int
}
matchers = append(matchers, matcher)
}
clients[clientIdx].expectIPs = matchers
s.clients[clientIdx].expectIPs = matchers
}
return nil
}

func establishFakeDNS(s *DNS, config *Config, nsClientMap map[int]int) error {
fakeHolders := &fakedns.HolderMulti{}
fakeDefault := (*fakedns.HolderMulti)(nil)
if config.FakeDns != nil {
defaultEngine, err := fakeHolders.AddPoolMulti(config.FakeDns)
if err != nil {
return newError("fail to create fake dns").Base(err).AtWarning()
}
fakeDefault = defaultEngine
}
for nsIdx, ns := range config.NameServer {
clientIdx := nsClientMap[nsIdx]
if ns.FakeDns == nil {
continue
}
engine, err := fakeHolders.AddPoolMulti(ns.FakeDns)
if err != nil {
return newError("fail to create fake dns").Base(err).AtWarning()
}
s.clients[clientIdx].fakeDNS = NewFakeDNSServer(engine)
s.clients[clientIdx].queryStrategy.FakeEnable = true
}
// Do not create FakeDNSEngine feature if no FakeDNS server is configured
if fakeHolders.IsEmpty() {
return nil
}
// Add FakeDNSEngine feature when DNS feature is added for the first time
s.fakeDNSEngine = &FakeDNSEngine{dns: s, fakeHolders: fakeHolders, fakeDefault: fakeDefault}
return core.RequireFeatures(s.ctx, func(client dns.Client) error {
v := core.MustFromContext(s.ctx)
if v.GetFeature(dns.FakeDNSEngineType()) != nil {
return nil
}
if client, ok := client.(dns.ClientWithFakeDNS); ok {
return v.AddFeature(client.AsFakeDNSEngine())
}
return nil
})
}

// Type implements common.HasType.
func (*DNS) Type() interface{} {
return dns.ClientType()
Expand All @@ -200,25 +244,29 @@ func (s *DNS) IsOwnLink(ctx context.Context) bool {
return inbound != nil && s.clientTags[inbound.Tag]
}

// AsFakeDNSClient implements dns.ClientWithFakeDNS.
func (s *DNS) AsFakeDNSClient() dns.Client {
return &FakeDNSClient{DNS: s}
}

// AsFakeDNSEngine implements dns.ClientWithFakeDNS.
func (s *DNS) AsFakeDNSEngine() dns.FakeDNSEngine {
return s.fakeDNSEngine
}

// LookupIP implements dns.Client.
func (s *DNS) LookupIP(domain string) ([]net.IP, error) {
return s.lookupIPInternal(domain, s.ipOption)
return s.lookupIPInternal(domain, dns.IPOption{IPv4Enable: true, IPv6Enable: true, FakeEnable: false})
}

// LookupIPv4 implements dns.IPv4Lookup.
func (s *DNS) LookupIPv4(domain string) ([]net.IP, error) {
if option := s.ipOption.With(dns.IPOption{IPv4Enable: true}); option.IsValid() {
return s.lookupIPInternal(domain, option)
}
return nil, dns.ErrEmptyResponse
return s.lookupIPInternal(domain, dns.IPOption{IPv4Enable: true, FakeEnable: false})
}

// LookupIPv6 implements dns.IPv6Lookup.
func (s *DNS) LookupIPv6(domain string) ([]net.IP, error) {
if option := s.ipOption.With(dns.IPOption{IPv6Enable: true}); option.IsValid() {
return s.lookupIPInternal(domain, option)
}
return nil, dns.ErrEmptyResponse
return s.lookupIPInternal(domain, dns.IPOption{IPv6Enable: true, FakeEnable: false})
}

func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, error) {
Expand Down Expand Up @@ -257,33 +305,20 @@ func (s *DNS) lookupIPInternal(domain string, option dns.IPOption) ([]net.IP, er
newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
}
if err != context.Canceled && err != context.DeadlineExceeded && err != errExpectedIPNonMatch {
return nil, err // Continues lookup for certain errors
return nil, err // Only continue lookup for certain errors
}
}

if len(errs) == 0 {
return nil, dns.ErrEmptyResponse
}
return nil, newError("returning nil for domain ", domain).Base(errors.Combine(errs...))
}

// GetIPOption implements ClientWithIPOption.
func (s *DNS) GetIPOption() *dns.IPOption {
return &s.ipOption
}

// SetQueryOption implements ClientWithIPOption.
func (s *DNS) SetQueryOption(isIPv4Enable, isIPv6Enable bool) {
s.ipOption.IPv4Enable = isIPv4Enable
s.ipOption.IPv6Enable = isIPv6Enable
}

// SetFakeDNSOption implements ClientWithIPOption.
func (s *DNS) SetFakeDNSOption(isFakeEnable bool) {
s.ipOption.FakeEnable = isFakeEnable
}

func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
clients := make([]*Client, 0, len(s.clients))
clientUsed := make([]bool, len(s.clients))
clientNames := make([]string, 0, len(s.clients))
clientIdxs := make([]int, 0, len(s.clients))
domainRules := []string{}

// Priority domain matching
Expand All @@ -295,12 +330,12 @@ func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
switch {
case clientUsed[info.clientIdx]:
continue
case !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS"):
case !option.FakeEnable && isFakeDNS(client.server):
continue
}
clientUsed[info.clientIdx] = true
clients = append(clients, client)
clientNames = append(clientNames, client.Name())
clientIdxs = append(clientIdxs, int(info.clientIdx))
}

// Default round-robin query
Expand All @@ -309,7 +344,7 @@ func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
switch {
case clientUsed[idx]:
continue
case !option.FakeEnable && strings.EqualFold(client.Name(), "FakeDNS"):
case !option.FakeEnable && isFakeDNS(client.server):
continue
case client.fallbackStrategy == FallbackStrategy_Disabled:
continue
Expand All @@ -318,19 +353,42 @@ func (s *DNS) sortClients(domain string, option dns.IPOption) []*Client {
}
clientUsed[idx] = true
clients = append(clients, client)
clientNames = append(clientNames, client.Name())
clientIdxs = append(clientIdxs, idx)
}

if len(domainRules) > 0 {
newError("domain ", domain, " matches following rules: ", domainRules).AtDebug().WriteToLog()
}
if len(clientNames) > 0 {
newError("domain ", domain, " will use DNS in order: ", clientNames, " ", toReqTypes(option)).AtDebug().WriteToLog()
if len(clientIdxs) > 0 {
newError("domain ", domain, " will use DNS in order: ", s.formatClientNames(clientIdxs, option), " ", toReqTypes(option)).AtDebug().WriteToLog()
}

return clients
}

func (s *DNS) formatClientNames(clientIdxs []int, option dns.IPOption) []string {
clientNames := make([]string, 0, len(clientIdxs))
counter := make(map[string]uint, len(clientIdxs))
for _, clientIdx := range clientIdxs {
client := s.clients[clientIdx]
var name string
if option.With(client.queryStrategy).FakeEnable {
name = fmt.Sprintf("%s(DNS idx:%d)", client.fakeDNS.Name(), clientIdx)
} else {
name = client.Name()
}
counter[name]++
clientNames = append(clientNames, name)
}
for idx, clientIdx := range clientIdxs {
name := clientNames[idx]
if counter[name] > 1 {
clientNames[idx] = fmt.Sprintf("%s(DNS idx:%d)", name, clientIdx)
}
}
return clientNames
}

func init() {
common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return New(ctx, config.(*Config))
Expand Down
10 changes: 10 additions & 0 deletions app/dns/dnscommon.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,13 @@ L:

return ipRecord, nil
}

func filterIP(ips []net.Address, option dns_feature.IPOption) []net.Address {
filtered := make([]net.Address, 0, len(ips))
for _, ip := range ips {
if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
filtered = append(filtered, ip)
}
}
return filtered
}
Loading

0 comments on commit f8ac919

Please sign in to comment.