Skip to content

Commit

Permalink
Make DNS hijack optional
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jun 1, 2024
1 parent 9da0dd9 commit f4c7198
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 112 deletions.
63 changes: 35 additions & 28 deletions redirect_iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os/exec"
"strings"

"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"

Expand Down Expand Up @@ -109,42 +110,48 @@ func (r *autoRedirect) setupIPTables(family int) error {
return err
}
}
var dnsServerAddress netip.Addr
if family == unix.AF_INET {
dnsServerAddress = r.tunOptions.Inet4Address[0].Addr().Next()
} else {
dnsServerAddress = r.tunOptions.Inet6Address[0].Addr().Next()
}
if len(routeAddress) > 0 {
for _, address := range routeAddress {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-d", address.String(), "-p udp --dport 53 -j DNAT --to", dnsServerAddress)
if err != nil {
return err
if r.tunOptions.DNSHijack {
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is4() == (family == unix.AF_INET)
})
if !dnsServer.IsValid() {
if family == unix.AF_INET {
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
} else {
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
}
}
} else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
for _, name := range r.tunOptions.IncludeInterface {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-i", name, "-p udp --dport 53 -j DNAT --to", dnsServerAddress)
if err != nil {
return err
if len(routeAddress) > 0 {
for _, address := range routeAddress {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-d", address.String(), "-p udp --dport 53 -j DNAT --to", dnsServer)
if err != nil {
return err
}
}
}
for _, uidRange := range r.tunOptions.IncludeUID {
for uid := uidRange.Start; uid <= uidRange.End; uid++ {
} else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
for _, name := range r.tunOptions.IncludeInterface {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-m owner --uid-owner", uid, "-p udp --dport 53 -j DNAT --to", dnsServerAddress)
"-i", name, "-p udp --dport 53 -j DNAT --to", dnsServer)
if err != nil {
return err
}
}
}
} else {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-p udp --dport 53 -j DNAT --to", dnsServerAddress)
if err != nil {
return err
for _, uidRange := range r.tunOptions.IncludeUID {
for uid := uidRange.Start; uid <= uidRange.End; uid++ {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-m owner --uid-owner", uid, "-p udp --dport 53 -j DNAT --to", dnsServer)
if err != nil {
return err
}
}
}
} else {
err = r.runShell(iptablesPath, "-r nat -A", tableNamePreRouteing,
"-p udp --dport 53 -j DNAT --to", dnsServer)
if err != nil {
return err
}
}
}

Expand Down
9 changes: 5 additions & 4 deletions redirect_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package tun

import (
"context"
"net/netip"
"os"
"os/exec"
"runtime"

"github.com/sagernet/nftables"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"net/netip"
"os"
"os/exec"
"runtime"

"golang.org/x/sys/unix"
)
Expand Down
50 changes: 28 additions & 22 deletions redirect_nftables.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/binaryutil"
"github.com/sagernet/nftables/expr"
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"

"golang.org/x/sys/unix"
Expand Down Expand Up @@ -138,34 +139,39 @@ func (r *autoRedirect) setupNFTables(family int) error {
}
}

var dnsServerAddress netip.Addr
if table.Family == nftables.TableFamilyIPv4 {
dnsServerAddress = r.tunOptions.Inet4Address[0].Addr().Next()
} else {
dnsServerAddress = r.tunOptions.Inet6Address[0].Addr().Next()
}

if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
for _, name := range r.tunOptions.IncludeInterface {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...)...),
})
if r.tunOptions.DNSHijack {
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is4() == (family == unix.AF_INET)
})
if !dnsServer.IsValid() {
if family == unix.AF_INET {
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
} else {
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
}
}
for _, uidRange := range r.tunOptions.IncludeUID {
if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
for _, name := range r.tunOptions.IncludeInterface {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...),
})
}
for _, uidRange := range r.tunOptions.IncludeUID {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...),
})
}
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...)...),
Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...),
})
}
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRouting,
Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServerAddress)...),
})
}

nft.AddRule(&nftables.Rule{
Expand Down
1 change: 0 additions & 1 deletion stack_gvisor_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Sock
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)

if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return wrapStackError(err)
Expand Down
3 changes: 3 additions & 0 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type Options struct {
MTU uint32
GSO bool
AutoRoute bool
DNSHijack bool
DNSServers []netip.Addr
IPRoute2RuleIndex int
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix
Expand Down
45 changes: 29 additions & 16 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,6 @@ func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) {
}), nil
}

const (
ruleStart = 9000
ruleEnd = ruleStart + 10
)

func (t *NativeTun) nextIndex6() int {
ruleList, err := netlink.RuleList(netlink.FAMILY_V6)
if err != nil {
Expand Down Expand Up @@ -411,9 +406,14 @@ func (t *NativeTun) rules() []*netlink.Rule {
var it *netlink.Rule

excludeRanges := t.options.ExcludedRanges()

ruleStart := t.options.IPRoute2RuleIndex
if ruleStart == 0 {
ruleStart = 9000
}
priority := ruleStart
priority6 := priority
nopPriority := ruleEnd
nopPriority := ruleStart + 10

for _, excludeRange := range excludeRanges {
if p4 {
Expand Down Expand Up @@ -788,6 +788,11 @@ func (t *NativeTun) unsetRules() error {
return err
}
for _, rule := range ruleList {
ruleStart := t.options.IPRoute2RuleIndex
if ruleStart == 0 {
ruleStart = 9000
}
ruleEnd := ruleStart + 10
if rule.Priority >= ruleStart && rule.Priority <= ruleEnd {
ruleToDel := netlink.NewRule()
ruleToDel.Family = rule.Family
Expand Down Expand Up @@ -820,20 +825,28 @@ func (t *NativeTun) routeUpdate(event int) {
}

func (t *NativeTun) setSearchDomainForSystemdResolved() {
if !t.options.DNSHijack {
return
}
ctlPath, err := exec.LookPath("resolvectl")
if err != nil {
return
}
var dnsServer []netip.Addr
if len(t.options.Inet4Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next())
}
if len(t.options.Inet6Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next())
dnsServer := t.options.DNSServers
if len(dnsServer) == 0 {
if len(t.options.Inet4Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next())
}
if len(t.options.Inet6Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next())
}
}
go shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run()
if t.options.AutoRoute {
go shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run()
go shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run()
if len(dnsServer) == 0 {
return
}
go func() {
_ = shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run()
_ = shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run()
_ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run()
}()
}
92 changes: 51 additions & 41 deletions tun_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,31 @@ func (t *NativeTun) configure() error {
if err != nil {
return E.Cause(err, "set ipv4 address")
}
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), []netip.Addr{t.options.Inet4Address[0].Addr().Next()}, nil)
if err != nil {
return E.Cause(err, "set ipv4 dns")
if t.options.DNSHijack {
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4)
if len(dnsServers) == 0 {
dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()}
}
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), dnsServers, nil)
if err != nil {
return E.Cause(err, "set ipv4 dns")
}
}
}
if len(t.options.Inet6Address) > 0 {
err := luid.SetIPAddressesForFamily(winipcfg.AddressFamily(windows.AF_INET6), t.options.Inet6Address)
if err != nil {
return E.Cause(err, "set ipv6 address")
}
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), []netip.Addr{t.options.Inet6Address[0].Addr().Next()}, nil)
if err != nil {
return E.Cause(err, "set ipv6 dns")
if t.options.DNSHijack {
dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6)
if len(dnsServers) == 0 {
dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()}
}
err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), dnsServers, nil)
if err != nil {
return E.Cause(err, "set ipv6 dns")
}
}
}
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
Expand Down Expand Up @@ -284,42 +296,40 @@ func (t *NativeTun) configure() error {
}
}

blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 2)
blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_PROTOCOL
blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL
blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT8
blockDNSCondition[0].ConditionValue.Value = uintptr(uint8(winsys.IPPROTO_UDP))
blockDNSCondition[1].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT
blockDNSCondition[1].MatchType = winsys.FWP_MATCH_EQUAL
blockDNSCondition[1].ConditionValue.Type = winsys.FWP_UINT16
blockDNSCondition[1].ConditionValue.Value = uintptr(uint16(53))

blockDNSFilter4 := winsys.FWPM_FILTER0{}
blockDNSFilter4.FilterCondition = &blockDNSCondition[0]
blockDNSFilter4.NumFilterConditions = 2
blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns")
blockDNSFilter4.SubLayerKey = subLayerKey
blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4
blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK
blockDNSFilter4.Weight.Type = winsys.FWP_UINT8
blockDNSFilter4.Weight.Value = uintptr(10)
err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId)
if err != nil {
return os.NewSyscallError("FwpmFilterAdd0", err)
}
if t.options.DNSHijack {
blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1)
blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT
blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL
blockDNSCondition[0].ConditionValue.Type = winsys.FWP_UINT16
blockDNSCondition[0].ConditionValue.Value = uintptr(uint16(53))

blockDNSFilter4 := winsys.FWPM_FILTER0{}
blockDNSFilter4.FilterCondition = &blockDNSCondition[0]
blockDNSFilter4.NumFilterConditions = 2
blockDNSFilter4.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv4 dns")
blockDNSFilter4.SubLayerKey = subLayerKey
blockDNSFilter4.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V4
blockDNSFilter4.Action.Type = winsys.FWP_ACTION_BLOCK
blockDNSFilter4.Weight.Type = winsys.FWP_UINT8
blockDNSFilter4.Weight.Value = uintptr(10)
err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter4, 0, &filterId)
if err != nil {
return os.NewSyscallError("FwpmFilterAdd0", err)
}

blockDNSFilter6 := winsys.FWPM_FILTER0{}
blockDNSFilter6.FilterCondition = &blockDNSCondition[0]
blockDNSFilter6.NumFilterConditions = 2
blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns")
blockDNSFilter6.SubLayerKey = subLayerKey
blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6
blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK
blockDNSFilter6.Weight.Type = winsys.FWP_UINT8
blockDNSFilter6.Weight.Value = uintptr(10)
err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId)
if err != nil {
return os.NewSyscallError("FwpmFilterAdd0", err)
blockDNSFilter6 := winsys.FWPM_FILTER0{}
blockDNSFilter6.FilterCondition = &blockDNSCondition[0]
blockDNSFilter6.NumFilterConditions = 2
blockDNSFilter6.DisplayData = winsys.CreateDisplayData(TunnelType, "block ipv6 dns")
blockDNSFilter6.SubLayerKey = subLayerKey
blockDNSFilter6.LayerKey = winsys.FWPM_LAYER_ALE_AUTH_CONNECT_V6
blockDNSFilter6.Action.Type = winsys.FWP_ACTION_BLOCK
blockDNSFilter6.Weight.Type = winsys.FWP_UINT8
blockDNSFilter6.Weight.Value = uintptr(10)
err = winsys.FwpmFilterAdd0(engine, &blockDNSFilter6, 0, &filterId)
if err != nil {
return os.NewSyscallError("FwpmFilterAdd0", err)
}
}
}

Expand Down

0 comments on commit f4c7198

Please sign in to comment.