diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index c0db6c9d666..097930bc080 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -218,8 +218,8 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { } srv := &Server{ - conf: ServerConfig{TLSConfig: tlsConf}, - logger: slogutil.NewDiscardLogger(), + conf: ServerConfig{TLSConfig: tlsConf}, + baseLogger: slogutil.NewDiscardLogger(), } var ( diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 7e12e8c3289..c2054217254 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -318,6 +318,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) conf = &proxy.Config{ + Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"), HTTP3: srvConf.ServeHTTP3, Ratelimit: int(srvConf.Ratelimit), RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, @@ -342,10 +343,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { MessageConstructor: s, } - if s.logger != nil { - conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy") - } - if srvConf.EDNSClientSubnet.UseCustom { // TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy. conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice()) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 98916797ede..107eea397a2 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -123,11 +123,9 @@ type Server struct { // access drops disallowed clients. access *accessManager - // logger is used for logging during server routines. - // - // TODO(d.kolyshev): Make it never nil. - // TODO(d.kolyshev): Use this logger. - logger *slog.Logger + // baseLogger is used to create loggers for other entities. It should not + // have a prefix and must not be nil. + baseLogger *slog.Logger // localDomainSuffix is the suffix used to detect internal hosts. It // must be a valid domain name plus dots on each side. @@ -246,7 +244,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { stats: p.Stats, queryLog: p.QueryLog, privateNets: p.PrivateNets, - logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"), + baseLogger: p.Logger, // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), etcHosts: etcHosts, @@ -615,7 +613,7 @@ func (s *Server) prepareInternalDNS() (err error) { return fmt.Errorf("preparing ipset settings: %w", err) } - ipsetLogger := s.logger.With(slogutil.KeyPrefix, "ipset") + ipsetLogger := s.baseLogger.With(slogutil.KeyPrefix, "ipset") s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList) if err != nil { // Don't wrap the error, because it's informative enough as is. @@ -685,7 +683,7 @@ func (s *Server) setupAddrProc() { s.addrProc = client.EmptyAddrProc{} } else { c := s.conf.AddrProcConf - c.BaseLogger = s.logger + c.BaseLogger = s.baseLogger c.DialContext = s.DialContext c.PrivateSubnets = s.privateNets c.UsePrivateRDNS = s.conf.UsePrivateRDNS @@ -729,6 +727,7 @@ func validateBlockingMode( func (s *Server) prepareInternalProxy() (err error) { srvConf := s.conf conf := &proxy.Config{ + Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"), CacheEnabled: true, CacheSizeBytes: 4096, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig, @@ -741,10 +740,6 @@ func (s *Server) prepareInternalProxy() (err error) { MessageConstructor: s, } - if s.logger != nil { - conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy") - } - err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) if err != nil { return fmt.Errorf("invalid upstream mode: %w", err) diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index ba1133b9a53..9781b3b0305 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -431,7 +431,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { dnsFilter: createTestDNSFilter(t), dhcpServer: dhcp, localDomainSuffix: localDomainSuffix, - logger: slogutil.NewDiscardLogger(), + baseLogger: slogutil.NewDiscardLogger(), } req := &dns.Msg{ @@ -567,7 +567,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { dnsFilter: createTestDNSFilter(t), dhcpServer: testDHCP, localDomainSuffix: tc.suffix, - logger: slogutil.NewDiscardLogger(), + baseLogger: slogutil.NewDiscardLogger(), } req := &dns.Msg{ diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index 8626c18003a..6e4d5d8623a 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -203,7 +203,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) { ql := &testQueryLog{} st := &testStats{} srv := &Server{ - logger: slogutil.NewDiscardLogger(), + baseLogger: slogutil.NewDiscardLogger(), queryLog: ql, stats: st, anonymizer: aghnet.NewIPMut(nil), diff --git a/internal/rdns/rdns.go b/internal/rdns/rdns.go index 4beb6cb315e..de84243cdfa 100644 --- a/internal/rdns/rdns.go +++ b/internal/rdns/rdns.go @@ -3,7 +3,6 @@ package rdns import ( "context" - "fmt" "log/slog" "net/netip" "time" @@ -96,7 +95,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan host, ttl, err := r.exchanger.Exchange(ip) if err != nil { - r.logger.DebugContext(ctx, "resolving ip", "ip", ip, slogutil.KeyError, err) + r.logger.DebugContext(ctx, "resolving", "ip", ip, slogutil.KeyError, err) } ttl = max(ttl, r.cacheTTL) @@ -108,7 +107,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan err = r.cache.Set(ip, item) if err != nil { - r.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err) + r.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err) } // TODO(e.burkov): The name doesn't change if it's neither stored in cache @@ -125,7 +124,7 @@ func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string, r.logger.DebugContext( ctx, "retrieving item from cache", - "item", ip, + "key", ip, slogutil.KeyError, err, ) } @@ -133,17 +132,7 @@ func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string, return "", true } - item, ok := val.(*cacheItem) - if !ok { - r.logger.DebugContext( - ctx, - "bad type of cache item", - "item", ip, - "type", fmt.Sprintf("%T", val), - ) - - return "", true - } + item := val.(*cacheItem) return item.host, time.Now().After(item.expiry) } diff --git a/internal/rdns/rdns_test.go b/internal/rdns/rdns_test.go index b8a9dfed973..f0b27ed87a0 100644 --- a/internal/rdns/rdns_test.go +++ b/internal/rdns/rdns_test.go @@ -1,7 +1,6 @@ package rdns_test import ( - "context" "net/netip" "testing" "time" @@ -114,14 +113,15 @@ func TestDefault_Process(t *testing.T) { return revAddr2, time.Hour, nil } + ctx := testutil.ContextWithTimeout(t, testTimeout) require.EventuallyWithT(t, func(t *assert.CollectT) { - got, changed = r.Process(context.TODO(), ip1) + got, changed = r.Process(ctx, ip1) assert.True(t, changed) assert.Equal(t, revAddr2, got) }, 2*cacheTTL, time.Millisecond*100) assert.Never(t, func() (changed bool) { - _, changed = r.Process(context.TODO(), ip1) + _, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1) return changed }, 2*cacheTTL, time.Millisecond*100) diff --git a/internal/whois/whois.go b/internal/whois/whois.go index 7119013181b..3dffd7a1fc6 100644 --- a/internal/whois/whois.go +++ b/internal/whois/whois.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/bluele/gcache" + "github.com/c2h5oh/datasize" ) const ( @@ -240,9 +241,9 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data [] // queryAll queries WHOIS server and handles redirects. func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) { server := net.JoinHostPort(w.serverAddr, w.portStr) - var data []byte for range w.maxRedirects { + var data []byte data, err = w.query(ctx, target, server) if err != nil { // Don't wrap the error since it's informative enough as is. @@ -252,7 +253,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string] w.logger.DebugContext( ctx, "received response", - "size", len(data), + "size", datasize.ByteSize(len(data)), "source", server, "target", target, ) @@ -315,7 +316,7 @@ func (w *Default) requestInfo( item := toCacheItem(info, w.cacheTTL) err := w.cache.Set(ip, item) if err != nil { - w.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err) + w.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err) } }() @@ -350,7 +351,7 @@ func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, exp w.logger.DebugContext( ctx, "retrieving item from cache", - "item", ip, + "key", ip, slogutil.KeyError, err, ) } @@ -358,19 +359,7 @@ func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, exp return nil, false } - item, ok := val.(*cacheItem) - if !ok { - w.logger.DebugContext( - ctx, - "bad type of cache item", - "item", ip, - "type", fmt.Sprintf("%T", val), - ) - - return nil, false - } - - return fromCacheItem(item) + return fromCacheItem(val.(*cacheItem)) } // Info is the filtered WHOIS data for a runtime client.