Skip to content

Commit

Permalink
feat: introduce a new robust approach to handle tproxy udp. (MetaCube…
Browse files Browse the repository at this point in the history
  • Loading branch information
Ovear authored Feb 17, 2023
1 parent b2d1cea commit 8e4dfbd
Show file tree
Hide file tree
Showing 10 changed files with 246 additions and 12 deletions.
73 changes: 69 additions & 4 deletions component/nat/table.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package nat

import (
"net"
"sync"

C "github.com/Dreamacro/clash/constant"
Expand All @@ -10,16 +11,24 @@ type Table struct {
mapping sync.Map
}

func (t *Table) Set(key string, pc C.PacketConn) {
t.mapping.Store(key, pc)
type Entry struct {
PacketConn C.PacketConn
LocalUDPConnMap sync.Map
}

func (t *Table) Set(key string, e C.PacketConn) {
t.mapping.Store(key, &Entry{
PacketConn: e,
LocalUDPConnMap: sync.Map{},
})
}

func (t *Table) Get(key string) C.PacketConn {
item, exist := t.mapping.Load(key)
entry, exist := t.getEntry(key)
if !exist {
return nil
}
return item.(C.PacketConn)
return entry.PacketConn
}

func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {
Expand All @@ -31,6 +40,62 @@ func (t *Table) Delete(key string) {
t.mapping.Delete(key)
}

func (t *Table) GetLocalConn(lAddr, rAddr string) *net.UDPConn {
entry, exist := t.getEntry(lAddr)
if !exist {
return nil
}
item, exist := entry.LocalUDPConnMap.Load(rAddr)
if !exist {
return nil
}
return item.(*net.UDPConn)
}

func (t *Table) AddLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool {
entry, exist := t.getEntry(lAddr)
if !exist {
return false
}
entry.LocalUDPConnMap.Store(rAddr, conn)
return true
}

func (t *Table) RangeLocalConn(lAddr string, f func(key, value any) bool) {
entry, exist := t.getEntry(lAddr)
if !exist {
return
}
entry.LocalUDPConnMap.Range(f)
}

func (t *Table) GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool) {
entry, loaded := t.getEntry(lAddr)
if !loaded {
return nil, false
}
item, loaded := entry.LocalUDPConnMap.LoadOrStore(key, sync.NewCond(&sync.Mutex{}))
return item.(*sync.Cond), loaded
}

func (t *Table) DeleteLocalConnMap(lAddr, key string) {
entry, loaded := t.getEntry(lAddr)
if !loaded {
return
}
entry.LocalUDPConnMap.Delete(key)
}

func (t *Table) getEntry(key string) (*Entry, bool) {
item, ok := t.mapping.Load(key)
// This should not happen usually since this function called after PacketConn created
if !ok {
return nil, false
}
entry, ok := item.(*Entry)
return entry, ok
}

// New return *Cache
func New() *Table {
return &Table{}
Expand Down
25 changes: 25 additions & 0 deletions constant/adapters.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"sync"
"time"

"github.com/Dreamacro/clash/component/dialer"
Expand Down Expand Up @@ -216,6 +217,10 @@ type UDPPacket interface {

// LocalAddr returns the source IP/Port of packet
LocalAddr() net.Addr

SetNatTable(natTable NatTable)

SetUdpInChan(in chan<- PacketAdapter)
}

type UDPPacketInAddr interface {
Expand All @@ -227,3 +232,23 @@ type PacketAdapter interface {
UDPPacket
Metadata() *Metadata
}

type NatTable interface {
Set(key string, e PacketConn)

Get(key string) PacketConn

GetOrCreateLock(key string) (*sync.Cond, bool)

Delete(key string)

GetLocalConn(lAddr, rAddr string) *net.UDPConn

AddLocalConn(lAddr, rAddr string, conn *net.UDPConn) bool

RangeLocalConn(lAddr string, f func(key, value any) bool)

GetOrCreateLockForLocalConn(lAddr, key string) (*sync.Cond, bool)

DeleteLocalConnMap(lAddr, key string)
}
8 changes: 8 additions & 0 deletions listener/shadowsocks/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/url"

"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)

Expand Down Expand Up @@ -44,6 +45,13 @@ func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr()
}

func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}

func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}
func ParseSSURL(s string) (addr, cipher, password string, err error) {
u, err := url.Parse(s)
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions listener/sing/sing.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,11 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr {
return c.lAddr
}

func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}

func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}
9 changes: 9 additions & 0 deletions listener/socks/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net"

"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/transport/socks5"
)

Expand Down Expand Up @@ -39,3 +40,11 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr()
}

func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}

func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}
98 changes: 91 additions & 7 deletions listener/tproxy/packet.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
package tproxy

import (
"errors"
"fmt"
"github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"net"
"net/netip"

"github.com/Dreamacro/clash/common/pool"
)

type packet struct {
pc net.PacketConn
lAddr netip.AddrPort
buf []byte
pc net.PacketConn
lAddr netip.AddrPort
buf []byte
natTable C.NatTable
in chan<- C.PacketAdapter
}

func (c *packet) Data() []byte {
Expand All @@ -19,13 +25,12 @@ func (c *packet) Data() []byte {

// WriteBack opens a new socket binding `addr` to write UDP packet back
func (c *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
tc, err := dialUDP("udp", addr.(*net.UDPAddr).AddrPort(), c.lAddr)
tc, err := createOrGetLocalConn(addr, c.LocalAddr(), c.natTable, c.in)
if err != nil {
n = 0
return
}
n, err = tc.Write(b)
tc.Close()
return
}

Expand All @@ -41,3 +46,82 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr()
}

func (c *packet) SetNatTable(natTable C.NatTable) {
c.natTable = natTable
}

func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
c.in = in
}

// this function listen at rAddr and write to lAddr
// for here, rAddr is the ip/port client want to access
// lAddr is the ip/port client opened
func createOrGetLocalConn(rAddr, lAddr net.Addr, natTable C.NatTable, in chan<- C.PacketAdapter) (*net.UDPConn, error) {
remote := rAddr.String()
local := lAddr.String()
localConn := natTable.GetLocalConn(local, remote)
// localConn not exist
if localConn == nil {
lockKey := remote + "-lock"
cond, loaded := natTable.GetOrCreateLockForLocalConn(local, lockKey)
if loaded {
cond.L.Lock()
cond.Wait()
// we should get localConn here
localConn = natTable.GetLocalConn(local, remote)
if localConn == nil {
return nil, fmt.Errorf("localConn is nil, nat entry not exist")
}
cond.L.Unlock()
} else {
if cond == nil {
return nil, fmt.Errorf("cond is nil, nat entry not exist")
}
defer func() {
natTable.DeleteLocalConnMap(local, lockKey)
cond.Broadcast()
}()
conn, err := listenLocalConn(rAddr, lAddr, in)
if err != nil {
log.Errorln("listenLocalConn failed with error: %s, packet loss", err.Error())
return nil, err
}
natTable.AddLocalConn(local, remote, conn)
localConn = conn
}
}
return localConn, nil
}

// this function listen at rAddr
// and send what received to program itself, then send to real remote
func listenLocalConn(rAddr, lAddr net.Addr, in chan<- C.PacketAdapter) (*net.UDPConn, error) {
additions := []inbound.Addition{
inbound.WithInName("DEFAULT-TPROXY"),
inbound.WithSpecialRules(""),
}
lc, err := dialUDP("udp", rAddr.(*net.UDPAddr).AddrPort(), lAddr.(*net.UDPAddr).AddrPort())
if err != nil {
return nil, err
}
go func() {
log.Debugln("TProxy listenLocalConn rAddr=%s lAddr=%s", rAddr.String(), lAddr.String())
for {
buf := pool.Get(pool.UDPBufferSize)
br, err := lc.Read(buf)
if err != nil {
pool.Put(buf)
if errors.Is(err, net.ErrClosed) {
log.Debugln("TProxy local conn listener exit.. rAddr=%s lAddr=%s", rAddr.String(), lAddr.String())
return
}
}
// since following localPackets are pass through this socket which listen rAddr
// I choose current listener as packet's packet conn
handlePacketConn(lc, in, buf[:br], lAddr.(*net.UDPAddr).AddrPort(), rAddr.(*net.UDPAddr).AddrPort(), additions...)
}
}()
return lc, nil
}
9 changes: 9 additions & 0 deletions listener/tunnel/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net"

"github.com/Dreamacro/clash/common/pool"
C "github.com/Dreamacro/clash/constant"
)

type packet struct {
Expand Down Expand Up @@ -33,3 +34,11 @@ func (c *packet) Drop() {
func (c *packet) InAddr() net.Addr {
return c.pc.LocalAddr()
}

func (c *packet) SetNatTable(natTable C.NatTable) {
// no need
}

func (c *packet) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}
8 changes: 8 additions & 0 deletions transport/tuic/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,5 +316,13 @@ func (s *serverUDPPacket) Drop() {
s.packet.DATA = nil
}

func (s *serverUDPPacket) SetNatTable(natTable C.NatTable) {
// no need
}

func (s *serverUDPPacket) SetUdpInChan(in chan<- C.PacketAdapter) {
// no need
}

var _ C.UDPPacket = &serverUDPPacket{}
var _ C.UDPPacketInAddr = &serverUDPPacket{}
15 changes: 15 additions & 0 deletions tunnel/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tunnel

import (
"errors"
"github.com/Dreamacro/clash/log"
"net"
"net/netip"
"time"
Expand Down Expand Up @@ -32,6 +33,7 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
buf := pool.Get(pool.UDPBufferSize)
defer func() {
_ = pc.Close()
closeAllLocalCoon(key)
natTable.Delete(key)
_ = pool.Put(buf)
}()
Expand Down Expand Up @@ -60,6 +62,19 @@ func handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, oAddr,
}
}

func closeAllLocalCoon(lAddr string) {
natTable.RangeLocalConn(lAddr, func(key, value any) bool {
conn, ok := value.(*net.UDPConn)
if !ok || conn == nil {
log.Debugln("Value %#v unknown value when closing TProxy local conn...", conn)
return true
}
conn.Close()
log.Debugln("Closing TProxy local conn... lAddr=%s rAddr=%s", lAddr, key)
return true
})
}

func handleSocket(ctx C.ConnContext, outbound net.Conn) {
N.Relay(ctx.Conn(), outbound)
}
Loading

0 comments on commit 8e4dfbd

Please sign in to comment.