From f4c719896da0b783a929f35d4df7fa1d63929c07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 1 Jun 2024 14:13:56 +0800 Subject: [PATCH] Make DNS hijack optional --- redirect_iptables.go | 63 ++++++++++++++++-------------- redirect_linux.go | 9 +++-- redirect_nftables.go | 50 +++++++++++++----------- stack_gvisor_udp.go | 1 - tun.go | 3 ++ tun_linux.go | 45 ++++++++++++++-------- tun_windows.go | 92 ++++++++++++++++++++++++-------------------- 7 files changed, 151 insertions(+), 112 deletions(-) diff --git a/redirect_iptables.go b/redirect_iptables.go index 7c2f551..a13c616 100644 --- a/redirect_iptables.go +++ b/redirect_iptables.go @@ -7,6 +7,7 @@ import ( "os/exec" "strings" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -109,42 +110,48 @@ func (r *autoRedirect) setupIPTables(family int) error { return err } } - var dnsServerAddress netip.Addr - if family == unix.AF_INET { - dnsServerAddress = r.tunOptions.Inet4Address[0].Addr().Next() - } else { - dnsServerAddress = r.tunOptions.Inet6Address[0].Addr().Next() - } - if len(routeAddress) > 0 { - for _, address := range routeAddress { - err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, - "-d", address.String(), "-p udp --dport 53 -j DNAT --to", dnsServerAddress) - if err != nil { - return err + if r.tunOptions.DNSHijack { + dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { + return it.Is4() == (family == unix.AF_INET) + }) + if !dnsServer.IsValid() { + if family == unix.AF_INET { + dnsServer = r.tunOptions.Inet4Address[0].Addr().Next() + } else { + dnsServer = r.tunOptions.Inet6Address[0].Addr().Next() } } - } else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { - for _, name := range r.tunOptions.IncludeInterface { - err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, - "-i", name, "-p udp --dport 53 -j DNAT --to", dnsServerAddress) - if err != nil { - return err + if len(routeAddress) > 0 { + for _, address := range routeAddress { + err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, + "-d", address.String(), "-p udp --dport 53 -j DNAT --to", dnsServer) + if err != nil { + return err + } } - } - for _, uidRange := range r.tunOptions.IncludeUID { - for uid := uidRange.Start; uid <= uidRange.End; uid++ { + } else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { + for _, name := range r.tunOptions.IncludeInterface { err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, - "-m owner --uid-owner", uid, "-p udp --dport 53 -j DNAT --to", dnsServerAddress) + "-i", name, "-p udp --dport 53 -j DNAT --to", dnsServer) if err != nil { return err } } - } - } else { - err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, - "-p udp --dport 53 -j DNAT --to", dnsServerAddress) - if err != nil { - return err + for _, uidRange := range r.tunOptions.IncludeUID { + for uid := uidRange.Start; uid <= uidRange.End; uid++ { + err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, + "-m owner --uid-owner", uid, "-p udp --dport 53 -j DNAT --to", dnsServer) + if err != nil { + return err + } + } + } + } else { + err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing, + "-p udp --dport 53 -j DNAT --to", dnsServer) + if err != nil { + return err + } } } diff --git a/redirect_linux.go b/redirect_linux.go index af1525b..3b2fdb3 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -2,15 +2,16 @@ package tun import ( "context" + "net/netip" + "os" + "os/exec" + "runtime" + "github.com/sagernet/nftables" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" - "net/netip" - "os" - "os/exec" - "runtime" "golang.org/x/sys/unix" ) diff --git a/redirect_nftables.go b/redirect_nftables.go index 6b3d290..d63226c 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -8,6 +8,7 @@ import ( "github.com/sagernet/nftables" "github.com/sagernet/nftables/binaryutil" "github.com/sagernet/nftables/expr" + "github.com/sagernet/sing/common" F "github.com/sagernet/sing/common/format" "golang.org/x/sys/unix" @@ -138,34 +139,39 @@ func (r *autoRedirect) setupNFTables(family int) error { } } - var dnsServerAddress netip.Addr - if table.Family == nftables.TableFamilyIPv4 { - dnsServerAddress = r.tunOptions.Inet4Address[0].Addr().Next() - } else { - dnsServerAddress = r.tunOptions.Inet6Address[0].Addr().Next() - } - - if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { - for _, name := range r.tunOptions.IncludeInterface { - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chainPreRouting, - Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...)...), - }) + if r.tunOptions.DNSHijack { + dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { + return it.Is4() == (family == unix.AF_INET) + }) + if !dnsServer.IsValid() { + if family == unix.AF_INET { + dnsServer = r.tunOptions.Inet4Address[0].Addr().Next() + } else { + dnsServer = r.tunOptions.Inet6Address[0].Addr().Next() + } } - for _, uidRange := range r.tunOptions.IncludeUID { + if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { + for _, name := range r.tunOptions.IncludeInterface { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...), + }) + } + for _, uidRange := range r.tunOptions.IncludeUID { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...), + }) + } + } else { nft.AddRule(&nftables.Rule{ Table: table, Chain: chainPreRouting, - Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...)...), + Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...), }) } - } else { - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chainPreRouting, - Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...), - }) } nft.AddRule(&nftables.Rule{ diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 4fbb0de..7d85185 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -138,7 +138,6 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock TTL: route.DefaultTTL(), TOS: 0, }, packet) - if err != nil { route.Stats().UDP.PacketSendErrors.Increment() return wrapStackError(err) diff --git a/tun.go b/tun.go index e08a60b..f4cfb29 100644 --- a/tun.go +++ b/tun.go @@ -48,6 +48,9 @@ type Options struct { MTU uint32 GSO bool AutoRoute bool + DNSHijack bool + DNSServers []netip.Addr + IPRoute2RuleIndex int StrictRoute bool Inet4RouteAddress []netip.Prefix Inet6RouteAddress []netip.Prefix diff --git a/tun_linux.go b/tun_linux.go index 9244146..a92e841 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -359,11 +359,6 @@ func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) { }), nil } -const ( - ruleStart = 9000 - ruleEnd = ruleStart + 10 -) - func (t *NativeTun) nextIndex6() int { ruleList, err := netlink.RuleList(netlink.FAMILY_V6) if err != nil { @@ -411,9 +406,14 @@ func (t *NativeTun) rules() []*netlink.Rule { var it *netlink.Rule excludeRanges := t.options.ExcludedRanges() + + ruleStart := t.options.IPRoute2RuleIndex + if ruleStart == 0 { + ruleStart = 9000 + } priority := ruleStart priority6 := priority - nopPriority := ruleEnd + nopPriority := ruleStart + 10 for _, excludeRange := range excludeRanges { if p4 { @@ -788,6 +788,11 @@ func (t *NativeTun) unsetRules() error { return err } for _, rule := range ruleList { + ruleStart := t.options.IPRoute2RuleIndex + if ruleStart == 0 { + ruleStart = 9000 + } + ruleEnd := ruleStart + 10 if rule.Priority >= ruleStart && rule.Priority <= ruleEnd { ruleToDel := netlink.NewRule() ruleToDel.Family = rule.Family @@ -820,20 +825,28 @@ func (t *NativeTun) routeUpdate(event int) { } func (t *NativeTun) setSearchDomainForSystemdResolved() { + if !t.options.DNSHijack { + return + } ctlPath, err := exec.LookPath("resolvectl") if err != nil { return } - var dnsServer []netip.Addr - if len(t.options.Inet4Address) > 0 { - dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next()) - } - if len(t.options.Inet6Address) > 0 { - dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next()) + dnsServer := t.options.DNSServers + if len(dnsServer) == 0 { + if len(t.options.Inet4Address) > 0 { + dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next()) + } + if len(t.options.Inet6Address) > 0 { + dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next()) + } } - go shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run() - if t.options.AutoRoute { - go shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run() - go shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run() + if len(dnsServer) == 0 { + return } + go func() { + _ = shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run() + _ = shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run() + _ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run() + }() } diff --git a/tun_windows.go b/tun_windows.go index db33954..44123af 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -72,9 +72,15 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv4 address") } - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil) - if err != nil { - return E.Cause(err, "set ipv4 dns") + if t.options.DNSHijack { + dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4) + if len(dnsServers) == 0 { + dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()} + } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), dnsServers, nil) + if err != nil { + return E.Cause(err, "set ipv4 dns") + } } } if len(t.options.Inet6Address) > 0 { @@ -82,9 +88,15 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv6 address") } - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil) - if err != nil { - return E.Cause(err, "set ipv6 dns") + if t.options.DNSHijack { + dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6) + if len(dnsServers) == 0 { + dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()} + } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), dnsServers, nil) + if err != nil { + return E.Cause(err, "set ipv6 dns") + } } } if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 { @@ -284,42 +296,40 @@ func (t *NativeTun) configure() error { } } - blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2) - blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL - blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL - blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8 - blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP)) - blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT - blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL - blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16 - blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53)) - - blockDNSFilter4 := winsys.FWPM_FILTER0{} - blockDNSFilter4.FilterCondition = &blockDNSCondition[0] - blockDNSFilter4.NumFilterConditions = 2 - blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns") - blockDNSFilter4.SubLayerKey = subLayerKey - blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 - blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK - blockDNSFilter4.Weight.Type = winsys.FWP_UINT8 - blockDNSFilter4.Weight.Value = uintptr(10) - err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId) - if err != nil { - return os.NewSyscallError("FwpmFilterAdd0", err) - } + if t.options.DNSHijack { + blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) + blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT + blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL + blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT16 + blockDNSCondition[0].ConditionValue.Value = uintptr(uint16(53)) + + blockDNSFilter4 := winsys.FWPM_FILTER0{} + blockDNSFilter4.FilterCondition = &blockDNSCondition[0] + blockDNSFilter4.NumFilterConditions = 2 + blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns") + blockDNSFilter4.SubLayerKey = subLayerKey + blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4 + blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK + blockDNSFilter4.Weight.Type = winsys.FWP_UINT8 + blockDNSFilter4.Weight.Value = uintptr(10) + err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId) + if err != nil { + return os.NewSyscallError("FwpmFilterAdd0", err) + } - blockDNSFilter6 := winsys.FWPM_FILTER0{} - blockDNSFilter6.FilterCondition = &blockDNSCondition[0] - blockDNSFilter6.NumFilterConditions = 2 - blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns") - blockDNSFilter6.SubLayerKey = subLayerKey - blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 - blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK - blockDNSFilter6.Weight.Type = winsys.FWP_UINT8 - blockDNSFilter6.Weight.Value = uintptr(10) - err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId) - if err != nil { - return os.NewSyscallError("FwpmFilterAdd0", err) + blockDNSFilter6 := winsys.FWPM_FILTER0{} + blockDNSFilter6.FilterCondition = &blockDNSCondition[0] + blockDNSFilter6.NumFilterConditions = 2 + blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns") + blockDNSFilter6.SubLayerKey = subLayerKey + blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6 + blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK + blockDNSFilter6.Weight.Type = winsys.FWP_UINT8 + blockDNSFilter6.Weight.Value = uintptr(10) + err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId) + if err != nil { + return os.NewSyscallError("FwpmFilterAdd0", err) + } } }