Skip to content

Commit

Permalink
Fix GSO write
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 14, 2023
1 parent 6a1419a commit b84e980
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 95 deletions.
2 changes: 1 addition & 1 deletion stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, bufio.NewVectorisedWriter(t.tun)}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err
Expand Down
76 changes: 48 additions & 28 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,14 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun/internal/clashtcpip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

type Mixed struct {
*System
writer N.VectorisedWriter
endpointIndependentNat bool
stack *stack.Stack
endpoint *channel.Endpoint
Expand All @@ -38,7 +35,6 @@ func NewMixed(
}
return &Mixed{
System: system.(*System),
writer: bufio.NewVectorisedWriter(options.Tun),
endpointIndependentNat: options.EndpointIndependentNat,
}, nil
}
Expand Down Expand Up @@ -95,7 +91,6 @@ func (m *Mixed) tunLoop() {
m.wintunLoop(winTun)
return
}

if batchTUN, isBatchTUN := m.tun.(BatchTUN); isBatchTUN {
batchSize := batchTUN.BatchSize()
if batchSize > 1 {
Expand All @@ -118,7 +113,12 @@ func (m *Mixed) tunLoop() {
}
rawPacket := packetBuffer[:frontHeadroom+n]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+n]
m.processPacket(rawPacket, packet)
if m.processPacket(packet) {
_, err = m.tun.Write(rawPacket)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}

Expand All @@ -132,7 +132,12 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
release()
continue
}
m.processPacket(packet, packet)
if m.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
Expand All @@ -141,6 +146,7 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := m.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset)
Expand All @@ -163,69 +169,83 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
rawPacket := packetBuffer[:frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
m.processPacket(rawPacket, packet)
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}

func (m *Mixed) processPacket(rawPacket []byte, packet []byte) {
var err error
func (m *Mixed) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(rawPacket, packet)
writeBack, err = m.processIPv4(packet)
case 6:
err = m.processIPv6(rawPacket, packet)
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
return false
}
return writeBack
}

func (m *Mixed) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error {
func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
writeBack = true
destination := packet.DestinationIP()
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
return common.Error(m.tun.Write(rawPacket))
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv4TCP(rawPacket, packet, packet.Payload())
err = m.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return nil
return
case clashtcpip.ICMP:
return m.processIPv4ICMP(rawPacket, packet, packet.Payload())
default:
return common.Error(m.tun.Write(rawPacket))
err = m.processIPv4ICMP(packet, packet.Payload())
}
return
}

func (m *Mixed) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error {
func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
writeBack = true
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(m.tun.Write(rawPacket))
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv6TCP(rawPacket, packet, packet.Payload())
err = m.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMPv6:
return m.processIPv6ICMP(rawPacket, packet, packet.Payload())
default:
return common.Error(m.tun.Write(rawPacket))
err = m.processIPv6ICMP(packet, packet.Payload())
}
return
}

func (m *Mixed) packetLoop() {
Expand All @@ -234,7 +254,7 @@ func (m *Mixed) packetLoop() {
if packet == nil {
break
}
bufio.WriteVectorised(m.writer, packet.AsSlices())
bufio.WriteVectorised(m.tun, packet.AsSlices())
packet.DecRef()
}
}
Expand Down
Loading

0 comments on commit b84e980

Please sign in to comment.