Skip to content

Commit

Permalink
Update BatchTUN API for WireGuard
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Dec 15, 2023
1 parent 0e13875 commit 3c677bc
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 33 deletions.
10 changes: 4 additions & 6 deletions stack_mixed.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,13 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := m.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom+PacketOffset)
readBuffers[i] = packetBuffers[i][frontHeadroom:]
packetBuffers[i] = make([]byte, m.mtu+frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -169,13 +167,13 @@ func (m *Mixed) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down
10 changes: 4 additions & 6 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,13 @@ func (s *System) wintunLoop(winTun WinTun) {
func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
frontHeadroom := s.tun.FrontHeadroom()
packetBuffers := make([][]byte, batchSize)
readBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom+PacketOffset)
readBuffers[i] = packetBuffers[i][frontHeadroom:]
packetBuffers[i] = make([]byte, s.mtu+frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(readBuffers, packetSizes)
n, err := linuxTUN.BatchRead(packetBuffers, frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
Expand All @@ -222,13 +220,13 @@ func (s *System) batchLoop(linuxTUN BatchTUN, batchSize int) {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[frontHeadroom+PacketOffset : frontHeadroom+packetSize]
packet := packetBuffer[frontHeadroom : frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers)
err = linuxTUN.BatchWrite(writeBuffers, frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down
4 changes: 2 additions & 2 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ type WinTun interface {
type BatchTUN interface {
Tun
BatchSize() int
BatchRead(buffers [][]byte, readN []int) (n int, err error)
BatchWrite(buffers [][]byte) error
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
}

type Options struct {
Expand Down
35 changes: 16 additions & 19 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type NativeTun struct {
ruleIndex6 []int
gsoEnabled bool
gsoBuffer []byte
gsoToWrite []int
tcpGROAccess sync.Mutex
tcp4GROTable *tcpGROTable
tcp6GROTable *tcpGROTable
Expand Down Expand Up @@ -105,7 +106,7 @@ func (t *NativeTun) Read(p []byte) (n int, err error) {

func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p})
err = t.BatchWrite([][]byte{p}, 0)
if err != nil {
return
}
Expand Down Expand Up @@ -140,37 +141,33 @@ func (t *NativeTun) BatchSize() int {
return batchSize
}

func (t *NativeTun) BatchRead(buffers [][]byte, readN []int) (n int, err error) {
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
n, err = handleVirtioRead(t.gsoBuffer[:n], buffers, readN, 0)
if err != nil {
return
}

return
} else {
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
if !t.gsoEnabled {
return 0, os.ErrInvalid
}
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
}

func (t *NativeTun) BatchWrite(buffers [][]byte) error {
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
println("batch write ", len(buffers))
t.tcpGROAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
}()
var toWrite []int
err := handleGRO(buffers, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite)
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
}
for _, bufferIndex := range toWrite {
_, err = t.tunFile.Write(buffers[bufferIndex])
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
if err != nil {
return err
}
Expand Down

0 comments on commit 3c677bc

Please sign in to comment.