diff --git a/pkg/loadRequest/loadDns/dns_requester.go b/pkg/loadRequest/loadDns/dns_requester.go index 4afd2d90..8f7a8c4e 100644 --- a/pkg/loadRequest/loadDns/dns_requester.go +++ b/pkg/loadRequest/loadDns/dns_requester.go @@ -25,12 +25,14 @@ package loadDns import ( + "context" "crypto/tls" "github.com/kdoctor-io/kdoctor/pkg/k8s/apis/system/v1beta1" "github.com/kdoctor-io/kdoctor/pkg/utils/stats" "github.com/miekg/dns" "go.uber.org/zap" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "net" "sync" "time" ) @@ -142,12 +144,22 @@ func (b *Work) Finish() { b.report.finalize(total) } -func (b *Work) makeRequest(client *dns.Client, conn *dns.Conn, wg *sync.WaitGroup) { +func (b *Work) makeRequest(conn *dns.Conn, wg *sync.WaitGroup) { defer wg.Done() var msg *dns.Msg var rtt time.Duration var err error + client := new(dns.Client) + client.Net = b.Protocol + client.Timeout = time.Duration(b.Timeout) * time.Millisecond + if b.Protocol == "tcp-tls" { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + client.TLSConfig = tlsConfig + } + if b.Protocol == "tcp" || b.Protocol == "tcp-tls" { msg, rtt, err = client.Exchange(b.Msg, b.ServerAddr) @@ -166,15 +178,10 @@ func (b *Work) makeRequest(client *dns.Client, conn *dns.Conn, wg *sync.WaitGrou } func (b *Work) runWorker() { - client := new(dns.Client) - client.Net = b.Protocol - client.Timeout = time.Duration(b.Timeout) * time.Millisecond - conn, _ := client.Dial(b.ServerAddr) - if b.Protocol == "tcp-tls" { - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - client.TLSConfig = tlsConfig + conn, err := b.makeConn() + if err != nil { + b.Logger.Sugar().Errorf("failed create dns conn,err=%v", err) + return } wg := &sync.WaitGroup{} for { @@ -185,7 +192,7 @@ func (b *Work) runWorker() { return case <-b.qosTokenBucket: wg.Add(1) - go b.makeRequest(client, conn, wg) + go b.makeRequest(conn, wg) } } } @@ -250,3 +257,12 @@ func (b *Work) AggregateMetric() *v1beta1.DNSMetrics { return metric } + +func (b *Work) makeConn() (*dns.Conn, error) { + var err error + d := net.Dialer{Timeout: time.Duration(b.Timeout) * time.Millisecond} + conn := new(dns.Conn) + conn.Conn, err = d.DialContext(context.Background(), "udp", b.ServerAddr) + + return conn, err +}