From 361300d362e75a1cee0af545ee8fb49b21d470c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 15 Dec 2023 12:18:45 +0800 Subject: [PATCH] Update BatchTUN API for WireGuard --- stack_mixed.go | 10 ++++------ stack_system.go | 10 ++++------ tun.go | 4 ++-- tun_linux.go | 31 ++++++++++++------------------- 4 files changed, 22 insertions(+), 33 deletions(-) diff --git a/stack_mixed.go b/stack_mixed.go index b52c996..7c1a8f9 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -145,15 +145,13 @@ func (m *Mixed) wintunLoop(winTun WinTun) { func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { frontHeadroom := m.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) - readBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset) - readBuffers[i] = packetBuffers[i][frontHeadroom:] + packetBuffers[i] = make([]byte, m.mtu+frontHeadroom) } for { - n, err := linuxTUN.BatchRead(readBuffers, packetSizes) + n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes) if err != nil { if E.IsClosed(err) { return @@ -169,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] + packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize] if m.processPacket(packet) { writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers) + err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom) if err != nil { m.logger.Trace(E.Cause(err, "batch write packet")) } diff --git a/stack_system.go b/stack_system.go index 73a83ae..8b687fa 100644 --- a/stack_system.go +++ b/stack_system.go @@ -198,15 +198,13 @@ func (s *System) wintunLoop(winTun WinTun) { func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { frontHeadroom := s.tun.FrontHeadroom() packetBuffers := make([][]byte, batchSize) - readBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset) - readBuffers[i] = packetBuffers[i][frontHeadroom:] + packetBuffers[i] = make([]byte, s.mtu+frontHeadroom) } for { - n, err := linuxTUN.BatchRead(readBuffers, packetSizes) + n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes) if err != nil { if E.IsClosed(err) { return @@ -222,13 +220,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) { continue } packetBuffer := packetBuffers[i] - packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize] + packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize] if s.processPacket(packet) { writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize]) } } if len(writeBuffers) > 0 { - err = linuxTUN.BatchWrite(writeBuffers) + err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom) if err != nil { s.logger.Trace(E.Cause(err, "batch write packet")) } diff --git a/tun.go b/tun.go index 5d853ee..9610782 100644 --- a/tun.go +++ b/tun.go @@ -36,8 +36,8 @@ type WinTun interface { type BatchTUN interface { Tun BatchSize() int - BatchRead(buffers [][]byte, readN []int) (n int, err error) - BatchWrite(buffers [][]byte) error + BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) + BatchWrite(buffers [][]byte, offset int) error } type Options struct { diff --git a/tun_linux.go b/tun_linux.go index 31903ef..9e9542d 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -35,6 +35,7 @@ type NativeTun struct { ruleIndex6 []int gsoEnabled bool gsoBuffer []byte + gsoToWrite []int tcpGROAccess sync.Mutex tcp4GROTable *tcpGROTable tcp6GROTable *tcpGROTable @@ -105,7 +106,7 @@ func (t *NativeTun) Read(p []byte) (n int, err error) { func (t *NativeTun) Write(p []byte) (n int, err error) { if t.gsoEnabled { - err = t.BatchWrite([][]byte{p}) + err = t.BatchWrite([][]byte{p}, 0) if err != nil { return } @@ -140,37 +141,29 @@ func (t *NativeTun) BatchSize() int { return batchSize } -func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) { - if t.gsoEnabled { - n, err = t.tunFile.Read(t.gsoBuffer) - if err != nil { - return - } - n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0) - if err != nil { - return - } - +func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) { + n, err = t.tunFile.Read(t.gsoBuffer) + if err != nil { return - } else { - return 0, os.ErrInvalid } + return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset) } -func (t *NativeTun) BatchWrite(buffers [][]byte) error { +func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error { t.tcpGROAccess.Lock() defer func() { t.tcp4GROTable.reset() t.tcp6GROTable.reset() t.tcpGROAccess.Unlock() }() - var toWrite []int - err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + t.gsoToWrite = t.gsoToWrite[:0] + err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite) if err != nil { return err } - for _, bufferIndex := range toWrite { - _, err = t.tunFile.Write(buffers[bufferIndex]) + offset -= virtioNetHdrLen + for _, bufferIndex := range t.gsoToWrite { + _, err = t.tunFile.Write(buffers[bufferIndex][offset:]) if err != nil { return err }