diff --git a/dns/resolver.go b/dns/resolver.go index 7e45252ee3..38c5d37506 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -19,6 +19,7 @@ import ( D "github.com/miekg/dns" "github.com/samber/lo" + "golang.org/x/exp/maps" "golang.org/x/sync/singleflight" ) @@ -370,6 +371,23 @@ type NameServer struct { PreferH3 bool } +func (ns NameServer) Equal(ns2 NameServer) bool { + defer func() { + // C.ProxyAdapter compare maybe panic, just ignore + recover() + }() + if ns.Net == ns2.Net && + ns.Addr == ns2.Addr && + ns.Interface == ns2.Interface && + ns.ProxyAdapter == ns2.ProxyAdapter && + ns.ProxyName == ns2.ProxyName && + maps.Equal(ns.Params, ns2.Params) && + ns.PreferH3 == ns2.PreferH3 { + return true + } + return false +} + type FallbackFilter struct { GeoIP bool GeoIPCode string @@ -399,20 +417,47 @@ func NewResolver(config Config) *Resolver { ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, } + var nameServerCache []struct { + NameServer + dnsClient + } + cacheTransform := func(nameserver []NameServer) (result []dnsClient) { + LOOP: + for _, ns := range nameserver { + for _, nsc := range nameServerCache { + if nsc.NameServer.Equal(ns) { + result = append(result, nsc.dnsClient) + continue LOOP + } + } + // not in cache + dc := transform([]NameServer{ns}, defaultResolver) + if len(dc) > 0 { + dc := dc[0] + nameServerCache = append(nameServerCache, struct { + NameServer + dnsClient + }{NameServer: ns, dnsClient: dc}) + result = append(result, dc) + } + } + return + } + r := &Resolver{ ipv6: config.IPv6, - main: transform(config.Main, defaultResolver), + main: cacheTransform(config.Main), lruCache: cache.New(cache.WithSize[string, *D.Msg](4096), cache.WithStale[string, *D.Msg](true)), hosts: config.Hosts, ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond, } if len(config.Fallback) != 0 { - r.fallback = transform(config.Fallback, defaultResolver) + r.fallback = cacheTransform(config.Fallback) } if len(config.ProxyServer) != 0 { - r.proxyServer = transform(config.ProxyServer, defaultResolver) + r.proxyServer = cacheTransform(config.ProxyServer) } if len(config.Policy) != 0 { @@ -426,6 +471,7 @@ func NewResolver(config Config) *Resolver { triePolicy = nil } } + for _, p := range config.Policy { domain, nameserver := p.Extract() domain = strings.ToLower(domain) @@ -439,7 +485,7 @@ func NewResolver(config Config) *Resolver { insertTriePolicy() r.policy = append(r.policy, domainSetPolicy{ domainSetProvider: p, - dnsClients: transform(nameserver, defaultResolver), + dnsClients: cacheTransform(nameserver), }) continue } @@ -458,7 +504,7 @@ func NewResolver(config Config) *Resolver { r.policy = append(r.policy, geositePolicy{ matcher: matcher, inverse: inverse, - dnsClients: transform(nameserver, defaultResolver), + dnsClients: cacheTransform(nameserver), }) continue } @@ -466,7 +512,7 @@ func NewResolver(config Config) *Resolver { if triePolicy == nil { triePolicy = trie.New[[]dnsClient]() } - _ = triePolicy.Insert(domain, transform(nameserver, defaultResolver)) + _ = triePolicy.Insert(domain, cacheTransform(nameserver)) } insertTriePolicy() }