diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index 6bdea236d..d0c7ebe3f 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -7,11 +7,8 @@ package netstack import ( "context" - "crypto/rand" - "encoding/binary" "errors" "fmt" - "io" "net" "os" "strconv" @@ -20,7 +17,6 @@ import ( "golang.zx2c4.com/wireguard/tun" - "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -39,6 +35,7 @@ type netTun struct { incomingPacket chan buffer.VectorisedView mtu int dnsServers []net.IP + resolver *net.Resolver hasV4, hasV6 bool } type endpoint netTun @@ -129,6 +126,11 @@ func CreateNetTUN(localAddresses []net.IP, dnsServers []net.IP, mtu int) (tun.De dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) } + dev.resolver = &net.Resolver{ + PreferGo: true, + Dial: (*Net)(dev).dialDNS, + } + dev.events <- tun.EventUp return dev, (*Net)(dev), nil } @@ -247,458 +249,14 @@ func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { } var ( - errNoSuchHost = errors.New("no such host") - errLameReferral = errors.New("lame referral") - errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") - errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") - errServerMisbehaving = errors.New("server misbehaving") - errInvalidDNSResponse = errors.New("invalid DNS response") - errNoAnswerFromDNSServer = errors.New("no answer from DNS server") - errServerTemporarilyMisbehaving = errors.New("server misbehaving") - errCanceled = errors.New("operation was canceled") - errTimeout = errors.New("i/o timeout") - errNumericPort = errors.New("port must be numeric") - errNoSuitableAddress = errors.New("no suitable address found") - errMissingAddress = errors.New("missing address") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") ) -func (net *Net) LookupHost(host string) (addrs []string, err error) { - return net.LookupContextHost(context.Background(), host) -} - -func isDomainName(s string) bool { - l := len(s) - if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { - return false - } - last := byte('.') - nonNumeric := false - partlen := 0 - for i := 0; i < len(s); i++ { - c := s[i] - switch { - default: - return false - case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': - nonNumeric = true - partlen++ - case '0' <= c && c <= '9': - partlen++ - case c == '-': - if last == '.' { - return false - } - partlen++ - nonNumeric = true - case c == '.': - if last == '.' || last == '-' { - return false - } - if partlen > 63 || partlen == 0 { - return false - } - partlen = 0 - } - last = c - } - if last == '-' || partlen > 63 { - return false - } - return nonNumeric -} - -func randU16() uint16 { - var b [2]byte - _, err := rand.Read(b[:]) - if err != nil { - panic(err) - } - return binary.LittleEndian.Uint16(b[:]) -} - -func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { - id = randU16() - b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) - b.EnableCompression() - if err := b.StartQuestions(); err != nil { - return 0, nil, nil, err - } - if err := b.Question(q); err != nil { - return 0, nil, nil, err - } - tcpReq, err = b.Finish() - udpReq = tcpReq[2:] - l := len(tcpReq) - 2 - tcpReq[0] = byte(l >> 8) - tcpReq[1] = byte(l) - return id, udpReq, tcpReq, err -} - -func equalASCIIName(x, y dnsmessage.Name) bool { - if x.Length != y.Length { - return false - } - for i := 0; i < int(x.Length); i++ { - a := x.Data[i] - b := y.Data[i] - if 'A' <= a && a <= 'Z' { - a += 0x20 - } - if 'A' <= b && b <= 'Z' { - b += 0x20 - } - if a != b { - return false - } - } - return true -} - -func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { - if !respHdr.Response { - return false - } - if reqID != respHdr.ID { - return false - } - if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { - return false - } - return true -} - -func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { - if _, err := c.Write(b); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - b = make([]byte, 512) - for { - n, err := c.Read(b) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - var p dnsmessage.Parser - h, err := p.Start(b[:n]) - if err != nil { - continue - } - q, err := p.Question() - if err != nil || !checkResponse(id, query, h, q) { - continue - } - return p, h, nil - } -} - -func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { - if _, err := c.Write(b); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - b = make([]byte, 1280) - if _, err := io.ReadFull(c, b[:2]); err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - l := int(b[0])<<8 | int(b[1]) - if l > len(b) { - b = make([]byte, l) - } - n, err := io.ReadFull(c, b[:l]) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - var p dnsmessage.Parser - h, err := p.Start(b[:n]) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage - } - q, err := p.Question() - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage - } - if !checkResponse(id, query, h, q) { - return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse - } - return p, h, nil -} - -func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { - q.Class = dnsmessage.ClassINET - id, udpReq, tcpReq, err := newRequest(q) - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage - } - - for _, useUDP := range []bool{true, false} { - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) - defer cancel() - - var c net.Conn - var err error - if useUDP { - c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) - } else { - c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) - } - - if err != nil { - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - if d, ok := ctx.Deadline(); ok && !d.IsZero() { - c.SetDeadline(d) - } - var p dnsmessage.Parser - var h dnsmessage.Header - if useUDP { - p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) - } else { - p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) - } - c.Close() - if err != nil { - if err == context.Canceled { - err = errCanceled - } else if err == context.DeadlineExceeded { - err = errTimeout - } - return dnsmessage.Parser{}, dnsmessage.Header{}, err - } - if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { - return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse - } - if h.Truncated { - continue - } - return p, h, nil - } - return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer -} - -func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { - if h.RCode == dnsmessage.RCodeNameError { - return errNoSuchHost - } - _, err := p.AnswerHeader() - if err != nil && err != dnsmessage.ErrSectionDone { - return errCannotUnmarshalDNSMessage - } - if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { - return errLameReferral - } - if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { - if h.RCode == dnsmessage.RCodeServerFailure { - return errServerTemporarilyMisbehaving - } - return errServerMisbehaving - } - return nil -} - -func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - return errNoSuchHost - } - if err != nil { - return errCannotUnmarshalDNSMessage - } - if h.Type == qtype { - return nil - } - if err := p.SkipAnswer(); err != nil { - return errCannotUnmarshalDNSMessage - } - } -} - -func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { - var lastErr error - - n, err := dnsmessage.NewName(name) - if err != nil { - return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage - } - q := dnsmessage.Question{ - Name: n, - Type: qtype, - Class: dnsmessage.ClassINET, - } - - for i := 0; i < 2; i++ { - for _, server := range tnet.dnsServers { - p, h, err := tnet.exchange(ctx, server, q, time.Second*5) - if err != nil { - dnsErr := &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - dnsErr.IsTimeout = true - } - if _, ok := err.(*net.OpError); ok { - dnsErr.IsTemporary = true - } - lastErr = dnsErr - continue - } - - if err := checkHeader(&p, h); err != nil { - dnsErr := &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if err == errServerTemporarilyMisbehaving { - dnsErr.IsTemporary = true - } - if err == errNoSuchHost { - dnsErr.IsNotFound = true - return p, server.String(), dnsErr - } - lastErr = dnsErr - continue - } - - err = skipToAnswer(&p, qtype) - if err == nil { - return p, server.String(), nil - } - lastErr = &net.DNSError{ - Err: err.Error(), - Name: name, - Server: server.String(), - } - if err == errNoSuchHost { - lastErr.(*net.DNSError).IsNotFound = true - return p, server.String(), lastErr - } - } - } - return dnsmessage.Parser{}, "", lastErr -} - -func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { - if host == "" || (!tnet.hasV6 && !tnet.hasV4) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - zlen := len(host) - if strings.IndexByte(host, ':') != -1 { - if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { - zlen = zidx - } - } - if ip := net.ParseIP(host[:zlen]); ip != nil { - return []string{host[:zlen]}, nil - } - - if !isDomainName(host) { - return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} - } - type result struct { - p dnsmessage.Parser - server string - error - } - var addrsV4, addrsV6 []net.IP - lanes := 0 - if tnet.hasV4 { - lanes++ - } - if tnet.hasV6 { - lanes++ - } - lane := make(chan result, lanes) - var lastErr error - if tnet.hasV4 { - go func() { - p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) - lane <- result{p, server, err} - }() - } - if tnet.hasV6 { - go func() { - p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) - lane <- result{p, server, err} - }() - } - for l := 0; l < lanes; l++ { - result := <-lane - if result.error != nil { - if lastErr == nil { - lastErr = result.error - } - continue - } - - loop: - for { - h, err := result.p.AnswerHeader() - if err != nil && err != dnsmessage.ErrSectionDone { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - } - if err != nil { - break - } - switch h.Type { - case dnsmessage.TypeA: - a, err := result.p.AResource() - if err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - addrsV4 = append(addrsV4, net.IP(a.A[:])) - - case dnsmessage.TypeAAAA: - aaaa, err := result.p.AAAAResource() - if err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) - - default: - if err := result.p.SkipAnswer(); err != nil { - lastErr = &net.DNSError{ - Err: errCannotMarshalDNSMessage.Error(), - Name: host, - Server: result.server, - } - break loop - } - continue - } - } - } - // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled - var addrs []net.IP - if tnet.hasV6 { - addrs = append(addrsV6, addrsV4...) - } else { - addrs = append(addrsV4, addrsV6...) - } - - if len(addrs) == 0 && lastErr != nil { - return nil, lastErr - } - saddrs := make([]string, 0, len(addrs)) - for _, ip := range addrs { - saddrs = append(saddrs, ip.String()) - } - return saddrs, nil -} +func (net *Net) Resolver() *net.Resolver { return net.resolver } func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { if deadline.IsZero() { @@ -748,20 +306,25 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. if err != nil || port < 0 || port > 65535 { return nil, &net.OpError{Op: "dial", Err: errNumericPort} } - allAddr, err := tnet.LookupContextHost(ctx, host) - if err != nil { - return nil, &net.OpError{Op: "dial", Err: err} - } + var addrs []net.IP - for _, addr := range allAddr { - if strings.IndexByte(addr, ':') != -1 && acceptV6 { - addrs = append(addrs, net.ParseIP(addr)) - } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { - addrs = append(addrs, net.ParseIP(addr)) + if addr := net.ParseIP(host); addr != nil { + addrs = []net.IP{addr} + } else { + allAddr, err := tnet.resolver.LookupHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + for _, addr := range allAddr { + if strings.IndexByte(addr, ':') != -1 && acceptV6 { + addrs = append(addrs, net.ParseIP(addr)) + } else if strings.IndexByte(addr, '.') != -1 && acceptV4 { + addrs = append(addrs, net.ParseIP(addr)) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } - } - if len(addrs) == 0 && len(allAddr) != 0 { - return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } var firstErr error @@ -816,3 +379,32 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net. func (tnet *Net) Dial(network, address string) (net.Conn, error) { return tnet.DialContext(context.Background(), network, address) } + +func (tnet *Net) dialDNS(ctx context.Context, network, address string) (net.Conn, error) { + if len(tnet.dnsServers) == 0 { + return tnet.DialContext(ctx, network, address) + } + + dnsIPs := make(map[string]struct{}, len(tnet.dnsServers)) + for _, dnsServer := range tnet.dnsServers { + ipAddress := dnsServer.String() + if host, _, err := net.SplitHostPort(address); err != nil && host == ipAddress { + return tnet.DialContext(ctx, network, address) + } + dnsIPs[ipAddress] = struct{}{} + } + + var lastErr error + for ipAddress := range dnsIPs { + conn, err := tnet.DialContext(ctx, network, net.JoinHostPort(ipAddress, "53")) + if err != nil { + if nerr, ok := err.(*net.OpError); ok { + lastErr = nerr + continue + } + return nil, err + } + return conn, nil + } + return nil, lastErr +}