Skip to content

Commit

Permalink
chore: packet deadline support CreateReadWaiter interface
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed May 20, 2023
1 parent 2b1e691 commit b047ca0
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 70 deletions.
51 changes: 29 additions & 22 deletions common/net/deadline/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package deadline
import (
"net"
"os"
"runtime"
"time"

"github.com/Dreamacro/clash/common/atomic"
Expand All @@ -13,8 +14,6 @@ type readResult struct {
data []byte
addr net.Addr
err error
enhanceReadResult
singReadResult
}

type NetPacketConn struct {
Expand All @@ -23,14 +22,14 @@ type NetPacketConn struct {
pipeDeadline pipeDeadline
disablePipe atomic.Bool
inRead atomic.Bool
resultCh chan *readResult
resultCh chan any
}

func NewNetPacketConn(pc net.PacketConn) net.PacketConn {
npc := &NetPacketConn{
PacketConn: pc,
pipeDeadline: makePipeDeadline(),
resultCh: make(chan *readResult, 1),
resultCh: make(chan any, 1),
}
npc.resultCh <- nil
if enhancePC, isEnhance := pc.(packet.EnhancePacketConn); isEnhance {
Expand Down Expand Up @@ -65,20 +64,28 @@ func NewNetPacketConn(pc net.PacketConn) net.PacketConn {
}

func (c *NetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
select {
case result := <-c.resultCh:
if result != nil {
n = copy(p, result.data)
addr = result.addr
err = result.err
c.resultCh <- nil // finish cache read
return
} else {
c.resultCh <- nil
break
FOR:
for {
select {
case result := <-c.resultCh:
if result != nil {
if result, ok := result.(*readResult); ok {
n = copy(p, result.data)
addr = result.addr
err = result.err
c.resultCh <- nil // finish cache read
return
}
c.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.resultCh <- nil
break FOR
}
case <-c.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}
case <-c.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded
}

if c.disablePipe.Load() {
Expand All @@ -100,11 +107,11 @@ func (c *NetPacketConn) pipeReadFrom(size int) {
buffer := make([]byte, size)
n, addr, err := c.PacketConn.ReadFrom(buffer)
buffer = buffer[:n]
c.resultCh <- &readResult{
data: buffer,
addr: addr,
err: err,
}
result := &readResult{}
result.data = buffer
result.addr = addr
result.err = err
c.resultCh <- result
}

func (c *NetPacketConn) SetReadDeadline(t time.Time) error {
Expand Down
56 changes: 33 additions & 23 deletions common/net/deadline/packet_enhance.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package deadline
import (
"net"
"os"
"runtime"

"github.com/Dreamacro/clash/common/net/packet"
)
Expand All @@ -19,7 +20,10 @@ func NewEnhancePacketConn(pc packet.EnhancePacketConn) packet.EnhancePacketConn
}

type enhanceReadResult struct {
put func()
data []byte
put func()
addr net.Addr
err error
}

type enhancePacketConn struct {
Expand All @@ -28,21 +32,29 @@ type enhancePacketConn struct {
}

func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
data = result.data
put = result.put
addr = result.addr
err = result.err
c.netPacketConn.resultCh <- nil // finish cache read
return
} else {
c.netPacketConn.resultCh <- nil
break
FOR:
for {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
if result, ok := result.(*enhanceReadResult); ok {
data = result.data
put = result.put
addr = result.addr
err = result.err
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
case <-c.netPacketConn.pipeDeadline.wait():
return nil, nil, nil, os.ErrDeadlineExceeded
}
case <-c.netPacketConn.pipeDeadline.wait():
return nil, nil, nil, os.ErrDeadlineExceeded
}

if c.netPacketConn.disablePipe.Load() {
Expand All @@ -62,12 +74,10 @@ func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Ad

func (c *enhancePacketConn) pipeWaitReadFrom() {
data, put, addr, err := c.enhancePacketConn.WaitReadFrom()
c.netPacketConn.resultCh <- &readResult{
data: data,
enhanceReadResult: enhanceReadResult{
put: put,
},
addr: addr,
err: err,
}
result := &enhanceReadResult{}
result.data = data
result.put = put
result.addr = addr
result.err = err
c.netPacketConn.resultCh <- result
}
134 changes: 109 additions & 25 deletions common/net/deadline/packet_sing.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package deadline

import (
"os"
"runtime"

"github.com/Dreamacro/clash/common/net/packet"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

type SingPacketConn struct {
Expand Down Expand Up @@ -33,6 +36,7 @@ var _ packet.EnhanceSingPacketConn = (*EnhanceSingPacketConn)(nil)
type singReadResult struct {
buffer *buf.Buffer
destination M.Socksaddr
err error
}

type singPacketConn struct {
Expand All @@ -41,26 +45,34 @@ type singPacketConn struct {
}

func (c *singPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
destination = result.destination
err = result.err
buffer.Resize(result.buffer.Start(), 0)
n := copy(buffer.FreeBytes(), result.buffer.Bytes())
buffer.Truncate(n)
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
FOR:
for {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
if result, ok := result.(*singReadResult); ok {
destination = result.destination
err = result.err
buffer.Resize(result.buffer.Start(), 0)
n := copy(buffer.FreeBytes(), result.buffer.Bytes())
buffer.Truncate(n)
result.buffer.Advance(n)
if result.buffer.IsEmpty() {
result.buffer.Release()
}
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
c.netPacketConn.resultCh <- nil // finish cache read
return
} else {
c.netPacketConn.resultCh <- nil
break
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}

if c.netPacketConn.disablePipe.Load() {
Expand All @@ -82,15 +94,87 @@ func (c *singPacketConn) pipeReadPacket(bufLen int, bufStart int) {
buffer := buf.NewSize(bufLen)
buffer.Advance(bufStart)
destination, err := c.singPacketConn.ReadPacket(buffer)
c.netPacketConn.resultCh <- &readResult{
singReadResult: singReadResult{
buffer: buffer,
destination: destination,
},
err: err,
}
result := &singReadResult{}
result.destination = destination
result.err = err
c.netPacketConn.resultCh <- result
}

func (c *singPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.singPacketConn.WritePacket(buffer, destination)
}

func (c *singPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
prw, isReadWaiter := bufio.CreatePacketReadWaiter(c.singPacketConn)
if isReadWaiter {
return &singPacketReadWaiter{
netPacketConn: c.netPacketConn,
packetReadWaiter: prw,
}, true
}
return nil, false
}

var _ N.PacketReadWaiter = (*singPacketReadWaiter)(nil)

type singPacketReadWaiter struct {
netPacketConn *NetPacketConn
packetReadWaiter N.PacketReadWaiter
}

type singWaitReadResult singReadResult

func (c *singPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
c.packetReadWaiter.InitializeReadWaiter(newBuffer)
}

func (c *singPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
FOR:
for {
select {
case result := <-c.netPacketConn.resultCh:
if result != nil {
if result, ok := result.(*singWaitReadResult); ok {
destination = result.destination
err = result.err
c.netPacketConn.resultCh <- nil // finish cache read
return
}
c.netPacketConn.resultCh <- result // another type of read
runtime.Gosched() // allowing other goroutines to run
continue FOR
} else {
c.netPacketConn.resultCh <- nil
break FOR
}
case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded
}
}

if c.netPacketConn.disablePipe.Load() {
return c.packetReadWaiter.WaitReadPacket()
} else if c.netPacketConn.deadline.Load().IsZero() {
c.netPacketConn.inRead.Store(true)
defer c.netPacketConn.inRead.Store(false)
destination, err = c.packetReadWaiter.WaitReadPacket()
return
}

<-c.netPacketConn.resultCh
go c.pipeWaitReadPacket()

return c.WaitReadPacket()
}

func (c *singPacketReadWaiter) pipeWaitReadPacket() {
destination, err := c.packetReadWaiter.WaitReadPacket()
result := &singWaitReadResult{}
result.destination = destination
result.err = err
c.netPacketConn.resultCh <- result
}

func (c *singPacketReadWaiter) Upstream() any {
return c.packetReadWaiter
}

0 comments on commit b047ca0

Please sign in to comment.