diff --git a/component/arp/arp.go b/component/arp/arp.go new file mode 100644 index 0000000000..446082bf58 --- /dev/null +++ b/component/arp/arp.go @@ -0,0 +1,47 @@ +// Package arp provides a general interface to retrieve the ARP table +// on both Linux and Windows + +// kanged from https://github.com/situation-sh/situation/blob/main/modules/arp/arp.go + +package arp + +import ( + "net" + "net/netip" + "sync" + + "github.com/metacubex/mihomo/log" +) + +var ( + table []ARPEntry + once sync.Once +) + +type ARPEntry struct { + MAC net.HardwareAddr + IP net.IP +} + +func IsReserved(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + return ip4[3] == 0 || ip4[3] == 255 + } + return false +} + +func IPToMac(ip netip.Addr) net.HardwareAddr { + once.Do(func() { + var err error + table, err = GetARPTable() + if err != nil { + log.Warnln("failed to get ARP table") + } + }) + for _, entry := range table { + if entry.IP.Equal(ip.AsSlice()) { + return entry.MAC + } + } + return nil +} diff --git a/component/arp/arp_linux.go b/component/arp/arp_linux.go new file mode 100644 index 0000000000..7e07bea09f --- /dev/null +++ b/component/arp/arp_linux.go @@ -0,0 +1,59 @@ +package arp + +import ( + "fmt" + "net" + + "github.com/sagernet/netlink" +) + +func neighMAC(n netlink.Neigh) net.HardwareAddr { + length := len(n.HardwareAddr) + mac := make(net.HardwareAddr, length) + copy(mac, n.HardwareAddr) + return mac +} + +func neighIP(n netlink.Neigh) net.IP { + length := len(n.IP) + ip := make(net.IP, length) + copy(ip, n.IP) + return ip +} + +func neighToARPEntry(n netlink.Neigh) ARPEntry { + return ARPEntry{ + MAC: neighMAC(n), + IP: neighIP(n), + } +} + +func GetARPTable() ([]ARPEntry, error) { + entries := make([]ARPEntry, 0) + + links, err := netlink.LinkList() + if err != nil { + return nil, err + } + + for _, link := range links { + attr := link.Attrs() + neighs, err := netlink.NeighList(attr.Index, 0) + if err != nil { + fmt.Println(err) + continue + } + for _, neigh := range neighs { + entry := neighToARPEntry(neigh) + + if IsReserved(entry.IP) { + continue + } + + if entry.IP.IsGlobalUnicast() { + entries = append(entries, entry) + } + } + } + return entries, nil +} diff --git a/component/arp/arp_other.go b/component/arp/arp_other.go new file mode 100644 index 0000000000..609847517a --- /dev/null +++ b/component/arp/arp_other.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package arp + +func GetARPTable() ([]ARPEntry, error) { + return nil, nil +} diff --git a/component/arp/arp_windows.go b/component/arp/arp_windows.go new file mode 100644 index 0000000000..fe47c568ff --- /dev/null +++ b/component/arp/arp_windows.go @@ -0,0 +1,22 @@ +package arp + +func GetARPTable() ([]ARPEntry, error) { + table, err := GetIpNetTable2() + if err != nil { + return nil, err + } + entries := make([]ARPEntry, 0) + for _, row := range table { + entry := row.ToARPEntry() + + // ignore 0 and 255 in case of IPv4 + if IsReserved(entry.IP) { + continue + } + + if entry.IP.IsGlobalUnicast() { + entries = append(entries, entry) + } + } + return entries, nil +} diff --git a/component/arp/get_ip_net_table2.go b/component/arp/get_ip_net_table2.go new file mode 100644 index 0000000000..c9f9ec7f74 --- /dev/null +++ b/component/arp/get_ip_net_table2.go @@ -0,0 +1,52 @@ +//go:build windows +// +build windows + +package arp + +import ( + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var iphlpapi *windows.DLL + +func init() { + iphlpapi = windows.MustLoadDLL("Iphlpapi.dll") +} + +func GetIpNetTable2() (MIBIpNetTable2, error) { + proc, err := iphlpapi.FindProc("GetIpNetTable2") + if err != nil { + return nil, err + } + + free, err := iphlpapi.FindProc("FreeMibTable") + if err != nil { + return nil, err + } + + var data *rawMIBIpNetTable2 + errno, _, _ := proc.Call(0, uintptr(unsafe.Pointer(&data))) + defer free.Call(uintptr(unsafe.Pointer(data))) + + switch syscall.Errno(errno) { + case windows.ERROR_SUCCESS: + err = nil + case windows.ERROR_NOT_ENOUGH_MEMORY: + err = fmt.Errorf("insufficient memory resources are available to complete the operation") + case windows.ERROR_INVALID_PARAMETER: + err = fmt.Errorf("an invalid parameter was passed to the function") + case windows.ERROR_NOT_FOUND: + err = fmt.Errorf("no neighbor IP address entries as specified in the Family parameter were found") + case windows.ERROR_NOT_SUPPORTED: + err = fmt.Errorf("the IPv4 or IPv6 transports are not configured on the local computer") + default: + err = windows.GetLastError() + } + + table := data.parse() + return table, err +} diff --git a/component/arp/mib_ipnet_row2.go b/component/arp/mib_ipnet_row2.go new file mode 100644 index 0000000000..657395a1db --- /dev/null +++ b/component/arp/mib_ipnet_row2.go @@ -0,0 +1,143 @@ +//go:build windows +// +build windows + +package arp + +import ( + "encoding/binary" + "net" + "time" +) + +const MIBIpNetRow2Size = 88 +const SockAddrSize = 28 + +type SockAddrIn struct { + sinFamily uint16 + sinPort uint16 + sinAddr net.IP + sinZero []byte +} + +func NewSockAddrIn(buffer []byte) SockAddrIn { + addr := SockAddrIn{ + sinFamily: binary.LittleEndian.Uint16(buffer[:2]), + sinPort: binary.LittleEndian.Uint16(buffer[2:4]), + sinAddr: net.IP(make([]byte, 4)).To4(), + sinZero: make([]byte, 8), + } + copy(addr.sinAddr, buffer[4:8]) + copy(addr.sinZero, buffer[8:16]) + return addr +} + +func (s SockAddrIn) Family() uint16 { + return s.sinFamily +} + +func (s SockAddrIn) Addr() net.IP { + return s.sinAddr.To4() +} + +type SockAddrIn6 struct { + sin6Family uint16 + sin6Port uint16 + sin6FlowInfo uint32 + sin6Addr net.IP + sin6ScopeId uint32 +} + +func NewSockAddrIn6(buffer []byte) SockAddrIn6 { + addr := SockAddrIn6{ + sin6Family: binary.LittleEndian.Uint16(buffer[:2]), + sin6Port: binary.LittleEndian.Uint16(buffer[2:4]), + sin6FlowInfo: binary.LittleEndian.Uint32(buffer[4:8]), + sin6Addr: net.IP(make([]byte, 16)).To16(), + sin6ScopeId: binary.LittleEndian.Uint32(buffer[24:28]), + } + copy(addr.sin6Addr, buffer[8:24]) + return addr +} + +func (s SockAddrIn6) Family() uint16 { + return s.sin6Family +} + +func (s SockAddrIn6) Addr() net.IP { + return s.sin6Addr.To16() +} + +type SockAddr interface { + Family() uint16 + Addr() net.IP +} + +func parseSockAddr(buffer []byte) SockAddr { + sockType := binary.LittleEndian.Uint16(buffer[:2]) + switch sockType { + case 2: // IPv4 + return NewSockAddrIn(buffer[:SockAddrSize]) + case 23: // IPv6 + return NewSockAddrIn6(buffer[:SockAddrSize]) + default: + return nil + } +} + +func parsePhysicalAddress(buffer []byte, physicalAddressLength uint32) net.HardwareAddr { + pa := make(net.HardwareAddr, physicalAddressLength) + copy(pa, buffer[:physicalAddressLength]) + return pa +} + +type MIBIpNetRow2 struct { + address SockAddr + interfaceIndex uint32 + interfaceLuid uint64 + physicalAddress net.HardwareAddr + physicalAddressLength uint32 + flags uint32 + reachabilityTime time.Duration +} + +func (r MIBIpNetRow2) MAC() net.HardwareAddr { + mac := make(net.HardwareAddr, r.physicalAddressLength) + copy(mac, r.physicalAddress) + return mac +} + +func (r MIBIpNetRow2) IP() net.IP { + length := len(r.address.Addr()) + ip := make(net.IP, length) + copy(ip, r.address.Addr()) + return ip +} + +func (r MIBIpNetRow2) ToARPEntry() ARPEntry { + return ARPEntry{ + MAC: r.MAC(), + IP: r.IP(), + } +} + +type rawMIBIpNetRow2 struct { + address [28]byte + interfaceIndex uint32 + interfaceLuid uint64 + physicalAddress [32]byte + physicalAddressLength uint32 + flags uint32 + reachabilityTime uint32 +} + +func (r rawMIBIpNetRow2) Parse() MIBIpNetRow2 { + return MIBIpNetRow2{ + address: parseSockAddr(r.address[:]), + interfaceIndex: r.interfaceIndex, + interfaceLuid: r.interfaceLuid, + physicalAddress: parsePhysicalAddress(r.physicalAddress[:], r.physicalAddressLength), + physicalAddressLength: r.physicalAddressLength, + flags: r.flags, + reachabilityTime: time.Duration(r.reachabilityTime * uint32(time.Millisecond)), + } +} diff --git a/component/arp/mib_ipnet_table2.go b/component/arp/mib_ipnet_table2.go new file mode 100644 index 0000000000..6a3823a5cb --- /dev/null +++ b/component/arp/mib_ipnet_table2.go @@ -0,0 +1,22 @@ +//go:build windows +// +build windows + +package arp + +const anySize = 1 << 16 + +type MIBIpNetTable2 []MIBIpNetRow2 + +type rawMIBIpNetTable2 struct { + numEntries uint32 + padding uint32 + table [anySize]rawMIBIpNetRow2 +} + +func (r *rawMIBIpNetTable2) parse() MIBIpNetTable2 { + t := make([]MIBIpNetRow2, r.numEntries) + for i := 0; i < int(r.numEntries); i++ { + t[i] = r.table[i].Parse() + } + return t +} diff --git a/constant/rule.go b/constant/rule.go index a91ee6cb07..e221bd0060 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -19,6 +19,7 @@ const ( DstPort InPort DSCP + Mac InUser InName InType @@ -98,6 +99,8 @@ func (rt RuleType) String() string { return "Uid" case SubRules: return "SubRules" + case Mac: + return "Mac" case AND: return "AND" case OR: diff --git a/rules/common/mac.go b/rules/common/mac.go new file mode 100644 index 0000000000..9d9333d0d0 --- /dev/null +++ b/rules/common/mac.go @@ -0,0 +1,41 @@ +package common + +import ( + "strings" + + "github.com/metacubex/mihomo/component/arp" + C "github.com/metacubex/mihomo/constant" +) + +type Mac struct { + *Base + mac string + adapter string +} + +func (m *Mac) RuleType() C.RuleType { + return C.Mac +} + +func (m *Mac) Match(metadata *C.Metadata) (bool, string) { + if arp.IPToMac(metadata.SrcIP).String() == m.mac { + return true, m.adapter + } + return false, m.adapter +} + +func (m *Mac) Adapter() string { + return m.adapter +} + +func (m *Mac) Payload() string { + return m.mac +} + +func NewMAC(mac string, adapter string) (*Mac, error) { + return &Mac{ + Base: &Base{}, + mac: strings.ReplaceAll(strings.ToLower(mac), "-", ":"), + adapter: adapter, + }, nil +} diff --git a/rules/parser.go b/rules/parser.go index 9b1f552007..178d1d10ab 100644 --- a/rules/parser.go +++ b/rules/parser.go @@ -49,6 +49,8 @@ func ParseRule(tp, payload, target string, params []string, subRules map[string] parsed, parseErr = RC.NewPort(payload, target, C.InPort) case "DSCP": parsed, parseErr = RC.NewDSCP(payload, target) + case "MAC", "SRC-MAC": + parsed, parseErr = RC.NewMAC(payload, target) case "PROCESS-NAME": parsed, parseErr = RC.NewProcess(payload, target, true, false) case "PROCESS-PATH":