diff --git a/stack.go b/stack.go index 7eb7854..43c5383 100644 --- a/stack.go +++ b/stack.go @@ -19,10 +19,7 @@ type Stack interface { type StackOptions struct { Context context.Context Tun Tun - Name string - MTU uint32 - Inet4Address []netip.Prefix - Inet6Address []netip.Prefix + TunOptions Options EndpointIndependentNat bool UDPTimeout int64 Handler Handler @@ -31,13 +28,21 @@ type StackOptions struct { InterfaceFinder control.InterfaceFinder } +func (o *StackOptions) BufferSize() uint32 { + if o.TunOptions.GSO { + return o.TunOptions.GSOMaxSize + } else { + return o.TunOptions.MTU + } +} + func NewStack( stack string, options StackOptions, ) (Stack, error) { switch stack { case "": - if WithGVisor { + if WithGVisor && !options.TunOptions.GSO { return NewMixed(options) } else { return NewSystem(options) @@ -48,8 +53,6 @@ func NewStack( return NewMixed(options) case "system": return NewSystem(options) - case "lwip": - return NewLWIP(options) default: return nil, E.New("unknown stack: ", stack) } diff --git a/stack_gvisor.go b/stack_gvisor.go index 108af21..6a1d0f3 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -31,7 +31,6 @@ const defaultNIC tcpip.NICID = 1 type GVisor struct { ctx context.Context tun GVisorTun - tunMtu uint32 endpointIndependentNat bool udpTimeout int64 broadcastAddr netip.Addr @@ -57,10 +56,9 @@ func NewGVisor( gStack := &GVisor{ ctx: options.Context, tun: gTun, - tunMtu: options.MTU, endpointIndependentNat: options.EndpointIndependentNat, udpTimeout: options.UDPTimeout, - broadcastAddr: BroadcastAddr(options.Inet4Address), + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), handler: options.Handler, logger: options.Logger, } @@ -72,7 +70,7 @@ func (t *GVisor) Start() error { if err != nil { return err } - linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun.CreateVectorisedWriter()} + linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, bufio.NewVectorisedWriter(t.tun)} ipStack, err := 
newGVisorStack(linkEndpoint) if err != nil { return err diff --git a/stack_lwip.go b/stack_lwip.go deleted file mode 100644 index 42cb651..0000000 --- a/stack_lwip.go +++ /dev/null @@ -1,144 +0,0 @@ -//go:build with_lwip - -package tun - -import ( - "context" - "net" - "net/netip" - "os" - - lwip "github.com/sagernet/go-tun2socks/core" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/udpnat" -) - -type LWIP struct { - ctx context.Context - tun Tun - tunMtu uint32 - udpTimeout int64 - handler Handler - stack lwip.LWIPStack - udpNat *udpnat.Service[netip.AddrPort] -} - -func NewLWIP( - options StackOptions, -) (Stack, error) { - return &LWIP{ - ctx: options.Context, - tun: options.Tun, - tunMtu: options.MTU, - handler: options.Handler, - stack: lwip.NewLWIPStack(), - udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler), - }, nil -} - -func (l *LWIP) Start() error { - lwip.RegisterTCPConnHandler(l) - lwip.RegisterUDPConnHandler(l) - lwip.RegisterOutputFn(l.tun.Write) - go l.loopIn() - return nil -} - -func (l *LWIP) loopIn() { - if winTun, isWintun := l.tun.(WinTun); isWintun { - l.loopInWintun(winTun) - return - } - buffer := make([]byte, int(l.tunMtu)+PacketOffset) - for { - n, err := l.tun.Read(buffer) - if err != nil { - return - } - _, err = l.stack.Write(buffer[PacketOffset:n]) - if err != nil { - if err.Error() == "stack closed" { - return - } - l.handler.NewError(context.Background(), err) - } - } -} - -func (l *LWIP) loopInWintun(tun WinTun) { - for { - packet, release, err := tun.ReadPacket() - if err != nil { - return - } - _, err = l.stack.Write(packet) - release() - if err != nil { - if err.Error() == "stack closed" { - return - } - l.handler.NewError(context.Background(), err) - } - } -} - -func (l *LWIP) Close() error { - lwip.RegisterTCPConnHandler(nil) - 
lwip.RegisterUDPConnHandler(nil) - lwip.RegisterOutputFn(func(bytes []byte) (int, error) { - return 0, os.ErrClosed - }) - return l.stack.Close() -} - -func (l *LWIP) Handle(conn net.Conn) error { - lAddr := conn.LocalAddr() - rAddr := conn.RemoteAddr() - if lAddr == nil || rAddr == nil { - conn.Close() - return nil - } - go func() { - var metadata M.Metadata - metadata.Source = M.SocksaddrFromNet(lAddr) - metadata.Destination = M.SocksaddrFromNet(rAddr) - hErr := l.handler.NewConnection(l.ctx, conn, metadata) - if hErr != nil { - conn.(lwip.TCPConn).Abort() - } - }() - return nil -} - -func (l *LWIP) ReceiveTo(conn lwip.UDPConn, data []byte, addr M.Socksaddr) error { - var upstreamMetadata M.Metadata - upstreamMetadata.Source = conn.LocalAddr() - upstreamMetadata.Destination = addr - - l.udpNat.NewPacket( - l.ctx, - upstreamMetadata.Source.AddrPort(), - buf.As(data).ToOwned(), - upstreamMetadata, - func(natConn N.PacketConn) N.PacketWriter { - return &LWIPUDPBackWriter{conn} - }, - ) - return nil -} - -type LWIPUDPBackWriter struct { - conn lwip.UDPConn -} - -func (w *LWIPUDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - defer buffer.Release() - return common.Error(w.conn.WriteFrom(buffer.Bytes(), destination)) -} - -func (w *LWIPUDPBackWriter) Close() error { - return w.conn.Close() -} diff --git a/stack_lwip_stub.go b/stack_lwip_stub.go deleted file mode 100644 index 403a45e..0000000 --- a/stack_lwip_stub.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !with_lwip - -package tun - -import E "github.com/sagernet/sing/common/exceptions" - -func NewLWIP( - options StackOptions, -) (Stack, error) { - return nil, E.New(`LWIP is not included in this build, rebuild with -tags with_lwip`) -} diff --git a/stack_mixed.go b/stack_mixed.go index 41d0ce2..710929f 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -38,7 +38,7 @@ func NewMixed( } return &Mixed{ System: system.(*System), - writer: options.Tun.CreateVectorisedWriter(), + writer: 
bufio.NewVectorisedWriter(options.Tun), endpointIndependentNat: options.EndpointIndependentNat, }, nil } @@ -95,7 +95,7 @@ func (m *Mixed) tunLoop() { m.wintunLoop(winTun) return } - packetBuffer := make([]byte, m.mtu+PacketOffset) + packetBuffer := make([]byte, m.bufferSize+m.tun.FrontHeadroom()) for { n, err := m.tun.Read(packetBuffer) if err != nil { @@ -104,12 +104,13 @@ func (m *Mixed) tunLoop() { if n < clashtcpip.IPv4PacketMinLength { continue } - packet := packetBuffer[PacketOffset:n] + rawPacket := packetBuffer[:n] + packet := packetBuffer[m.tun.FrontHeadroom():n] switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = m.processIPv4(packet) + err = m.processIPv4(rawPacket, packet) case 6: - err = m.processIPv6(packet) + err = m.processIPv6(rawPacket, packet) default: err = E.New("ip: unknown version: ", ipVersion) } @@ -131,9 +132,9 @@ func (m *Mixed) wintunLoop(winTun WinTun) { } switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = m.processIPv4(packet) + err = m.processIPv4(packet, packet) case 6: - err = m.processIPv6(packet) + err = m.processIPv6(packet, packet) default: err = E.New("ip: unknown version: ", ipVersion) } @@ -144,14 +145,14 @@ func (m *Mixed) wintunLoop(winTun WinTun) { } } -func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { +func (m *Mixed) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error { destination := packet.DestinationIP() if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { return common.Error(m.tun.Write(packet)) } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv4TCP(packet, packet.Payload()) + return m.processIPv4TCP(rawPacket, packet, packet.Payload()) case clashtcpip.UDP: pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), @@ -160,19 +161,19 @@ func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { pkt.DecRef() return nil case clashtcpip.ICMP: - return m.processIPv4ICMP(packet, 
packet.Payload()) + return m.processIPv4ICMP(rawPacket, packet, packet.Payload()) default: return common.Error(m.tun.Write(packet)) } } -func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { +func (m *Mixed) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error { if !packet.DestinationIP().IsGlobalUnicast() { return common.Error(m.tun.Write(packet)) } switch packet.Protocol() { case clashtcpip.TCP: - return m.processIPv6TCP(packet, packet.Payload()) + return m.processIPv6TCP(rawPacket, packet, packet.Payload()) case clashtcpip.UDP: pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), @@ -181,7 +182,7 @@ func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { pkt.DecRef() return nil case clashtcpip.ICMPv6: - return m.processIPv6ICMP(packet, packet.Payload()) + return m.processIPv6ICMP(rawPacket, packet, packet.Payload()) default: return common.Error(m.tun.Write(packet)) } diff --git a/stack_system.go b/stack_system.go index dd305fe..9773cab 100644 --- a/stack_system.go +++ b/stack_system.go @@ -23,6 +23,7 @@ type System struct { tun Tun tunName string mtu uint32 + bufferSize int handler Handler logger logger.Logger inet4Prefixes []netip.Prefix @@ -41,6 +42,7 @@ type System struct { udpNat *udpnat.Service[netip.AddrPort] bindInterface bool interfaceFinder control.InterfaceFinder + offload bool } type Session struct { @@ -54,29 +56,30 @@ func NewSystem(options StackOptions) (Stack, error) { stack := &System{ ctx: options.Context, tun: options.Tun, - tunName: options.Name, - mtu: options.MTU, + tunName: options.TunOptions.Name, + mtu: options.TunOptions.MTU, + bufferSize: int(options.BufferSize()), udpTimeout: options.UDPTimeout, handler: options.Handler, logger: options.Logger, - inet4Prefixes: options.Inet4Address, - inet6Prefixes: options.Inet6Address, - broadcastAddr: BroadcastAddr(options.Inet4Address), + inet4Prefixes: options.TunOptions.Inet4Address, + inet6Prefixes: 
options.TunOptions.Inet6Address, + broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), bindInterface: options.ForwarderBindInterface, interfaceFinder: options.InterfaceFinder, } - if len(options.Inet4Address) > 0 { - if options.Inet4Address[0].Bits() == 32 { + if len(options.TunOptions.Inet4Address) > 0 { + if options.TunOptions.Inet4Address[0].Bits() == 32 { return nil, E.New("need one more IPv4 address in first prefix for system stack") } - stack.inet4ServerAddress = options.Inet4Address[0].Addr() + stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr() stack.inet4Address = stack.inet4ServerAddress.Next() } - if len(options.Inet6Address) > 0 { - if options.Inet6Address[0].Bits() == 128 { + if len(options.TunOptions.Inet6Address) > 0 { + if options.TunOptions.Inet6Address[0].Bits() == 128 { return nil, E.New("need one more IPv6 address in first prefix for system stack") } - stack.inet6ServerAddress = options.Inet6Address[0].Addr() + stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr() stack.inet6Address = stack.inet6ServerAddress.Next() } if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() { @@ -144,21 +147,26 @@ func (s *System) tunLoop() { s.wintunLoop(winTun) return } - packetBuffer := make([]byte, s.mtu+PacketOffset) + frontHeadroom := s.tun.FrontHeadroom() + packetBuffer := make([]byte, s.bufferSize+frontHeadroom) for { - n, err := s.tun.Read(packetBuffer) + n, err := s.tun.Read(packetBuffer[frontHeadroom:]) if err != nil { - return + if E.IsClosed(err) { + return + } + s.logger.Error(E.Cause(err, "read packet")) } if n < clashtcpip.IPv4PacketMinLength { continue } - packet := packetBuffer[PacketOffset:n] + rawPacket := packetBuffer[:frontHeadroom+n] + packet := packetBuffer[frontHeadroom : frontHeadroom+n] switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = s.processIPv4(packet) + err = s.processIPv4(rawPacket, packet) case 6: - err = s.processIPv6(packet) + err = s.processIPv6(rawPacket, 
packet) default: err = E.New("ip: unknown version: ", ipVersion) } @@ -180,9 +188,9 @@ func (s *System) wintunLoop(winTun WinTun) { } switch ipVersion := packet[0] >> 4; ipVersion { case 4: - err = s.processIPv4(packet) + err = s.processIPv4(packet, packet) case 6: - err = s.processIPv6(packet) + err = s.processIPv6(packet, packet) default: err = E.New("ip: unknown version: ", ipVersion) } @@ -234,44 +242,44 @@ func (s *System) acceptLoop(listener net.Listener) { } } -func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error { +func (s *System) processIPv4(rawPacket []byte, packet clashtcpip.IPv4Packet) error { destination := packet.DestinationIP() if destination == s.broadcastAddr || !destination.IsGlobalUnicast() { return common.Error(s.tun.Write(packet)) } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv4TCP(packet, packet.Payload()) + return s.processIPv4TCP(rawPacket, packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv4UDP(packet, packet.Payload()) + return s.processIPv4UDP(rawPacket, packet, packet.Payload()) case clashtcpip.ICMP: - return s.processIPv4ICMP(packet, packet.Payload()) + return s.processIPv4ICMP(rawPacket, packet, packet.Payload()) default: - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } } -func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error { +func (s *System) processIPv6(rawPacket []byte, packet clashtcpip.IPv6Packet) error { if !packet.DestinationIP().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } switch packet.Protocol() { case clashtcpip.TCP: - return s.processIPv6TCP(packet, packet.Payload()) + return s.processIPv6TCP(rawPacket, packet, packet.Payload()) case clashtcpip.UDP: - return s.processIPv6UDP(packet, packet.Payload()) + return s.processIPv6UDP(rawPacket, packet, packet.Payload()) case clashtcpip.ICMPv6: - return s.processIPv6ICMP(packet, packet.Payload()) + return 
s.processIPv6ICMP(rawPacket, packet, packet.Payload()) default: - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } } -func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { +func (s *System) processIPv4TCP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -290,14 +298,14 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip. } header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } -func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { +func (s *System) processIPv6TCP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error { source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 { session := s.tcpNat.LookupBack(destination.Port()) if session == nil { @@ -316,10 +324,10 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip. 
} header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } -func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { +func (s *System) processIPv4UDP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error { if packet.Flags()&clashtcpip.FlagMoreFragment != 0 { return E.New("ipv4: fragment dropped") } @@ -332,7 +340,7 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -346,19 +354,19 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip. 
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter4{s.tun, headerCopy, source} + return &systemUDPPacketWriter4{s.tun, s.tun.FrontHeadroom(), headerCopy, source} }) return nil } -func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { +func (s *System) processIPv6UDP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error { if !header.Valid() { return E.New("ipv6: udp: invalid packet") } source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort()) destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort()) if !destination.Addr().IsGlobalUnicast() { - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } data := buf.As(header.Payload()) if data.Len() == 0 { @@ -372,12 +380,12 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip. 
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize headerCopy := make([]byte, headerLen) copy(headerCopy, packet[:headerLen]) - return &systemUDPPacketWriter6{s.tun, headerCopy, source} + return &systemUDPPacketWriter6{s.tun, s.tun.FrontHeadroom(), headerCopy, source} }) return nil } -func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { +func (s *System) processIPv4ICMP(rawPacket []byte, packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error { if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 { return nil } @@ -387,10 +395,10 @@ func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip packet.SetDestinationIP(sourceAddress) header.ResetChecksum() packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } -func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { +func (s *System) processIPv6ICMP(rawPacket []byte, packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error { if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 { return nil } @@ -400,102 +408,20 @@ func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip packet.SetDestinationIP(sourceAddress) header.ResetChecksum(packet.PseudoSum()) packet.ResetChecksum() - return common.Error(s.tun.Write(packet)) -} - -type systemTCPDirectPacketWriter4 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemTCPDirectPacketWriter4) WritePacket(p []byte) error { - packet := clashtcpip.IPv4Packet(p) - header := clashtcpip.TCPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemTCPDirectPacketWriter6 struct { - tun Tun - source netip.AddrPort -} - -func 
(w *systemTCPDirectPacketWriter6) WritePacket(p []byte) error { - packet := clashtcpip.IPv6Packet(p) - header := clashtcpip.TCPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemUDPDirectPacketWriter4 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemUDPDirectPacketWriter4) WritePacket(p []byte) error { - packet := clashtcpip.IPv4Packet(p) - header := clashtcpip.UDPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemUDPDirectPacketWriter6 struct { - tun Tun - source netip.AddrPort -} - -func (w *systemUDPDirectPacketWriter6) WritePacket(p []byte) error { - packet := clashtcpip.IPv6Packet(p) - header := clashtcpip.UDPPacket(packet.Payload()) - packet.SetDestinationIP(w.source.Addr()) - header.SetDestinationPort(w.source.Port()) - header.ResetChecksum(packet.PseudoSum()) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemICMPDirectPacketWriter4 struct { - tun Tun - source netip.Addr -} - -func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error { - packet := clashtcpip.IPv4Packet(p) - packet.SetDestinationIP(w.source) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) -} - -type systemICMPDirectPacketWriter6 struct { - tun Tun - source netip.Addr -} - -func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error { - packet := clashtcpip.IPv6Packet(p) - packet.SetDestinationIP(w.source) - packet.ResetChecksum() - return common.Error(w.tun.Write(packet)) + return common.Error(s.tun.Write(rawPacket)) } type systemUDPPacketWriter4 struct { - tun Tun - header []byte - source netip.AddrPort + tun Tun + frontHeadroom int 
+ header []byte + source netip.AddrPort } func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - newPacket := buf.NewSize(len(w.header) + buffer.Len()) + newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) defer newPacket.Release() + newPacket.Extend(w.frontHeadroom) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes()) @@ -512,14 +438,16 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S } type systemUDPPacketWriter6 struct { - tun Tun - header []byte - source netip.AddrPort + tun Tun + frontHeadroom int + header []byte + source netip.AddrPort } func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - newPacket := buf.NewSize(len(w.header) + buffer.Len()) + newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len()) defer newPacket.Release() + newPacket.Extend(w.frontHeadroom) newPacket.Write(w.header) newPacket.Write(buffer.Bytes()) ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes()) diff --git a/tun.go b/tun.go index 977d4eb..f897c28 100644 --- a/tun.go +++ b/tun.go @@ -23,7 +23,7 @@ type Handler interface { type Tun interface { io.ReadWriter - CreateVectorisedWriter() N.VectorisedWriter + N.FrontHeadroom Close() error } @@ -37,6 +37,8 @@ type Options struct { Inet4Address []netip.Prefix Inet6Address []netip.Prefix MTU uint32 + GSO bool + GSOMaxSize uint32 AutoRoute bool StrictRoute bool Inet4RouteAddress []netip.Prefix diff --git a/tun_darwin.go b/tun_darwin.go index 553eb02..c8d569f 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -20,8 +20,6 @@ import ( "golang.org/x/sys/unix" ) -const PacketOffset = 4 - type NativeTun struct { tunFile *os.File tunWriter N.VectorisedWriter @@ -72,6 +70,10 @@ func New(options Options) (Tun, error) { return nativeTun, nil } +func (t *NativeTun) FrontHeadroom() int { + return 4 +} + func (t *NativeTun) Read(p []byte) (n int, 
err error) { /*n, err = t.tunFile.Read(p) if n < 4 { @@ -83,37 +85,24 @@ func (t *NativeTun) Read(p []byte) (n int, err error) { return t.tunFile.Read(p) } -var ( - packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET} - packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6} -) - func (t *NativeTun) Write(p []byte) (n int, err error) { - var packetHeader []byte - if p[0]>>4 == 4 { - packetHeader = packetHeader4[:] + if p[4]>>4 == 4 { + p[3] = unix.AF_INET } else { - packetHeader = packetHeader6[:] - } - _, err = bufio.WriteVectorised(t.tunWriter, [][]byte{packetHeader, p}) - if err == nil { - n = len(p) + p[3] = unix.AF_INET6 } - return -} - -func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { - return t + return t.tunFile.Write(p) } func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { - var packetHeader []byte + packetHeader := buf.NewSize(4) + packetHeader.WriteZeroN(3) if buffers[0].Byte(0)>>4 == 4 { - packetHeader = packetHeader4[:] + packetHeader.WriteByte(unix.AF_INET) } else { - packetHeader = packetHeader6[:] + packetHeader.WriteByte(unix.AF_INET6) } - return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) + return t.tunWriter.WriteVectorised(append([]*buf.Buffer{packetHeader}, buffers...)) } func (t *NativeTun) Close() error { diff --git a/tun_linux.go b/tun_linux.go index a261455..62c3992 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -1,6 +1,7 @@ package tun import ( + "io" "math/rand" "net" "net/netip" @@ -12,9 +13,7 @@ import ( "github.com/sagernet/netlink" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/x/list" @@ -28,11 +27,15 @@ type NativeTun struct { interfaceCallback *list.Element[DefaultInterfaceUpdateCallback] options Options ruleIndex6 []int + 
gsoEnabled bool + gsoBuffer []byte + tcp4GROTable *tcpGROTable + tcp6GROTable *tcpGROTable } func New(options Options) (Tun, error) { if options.FileDescriptor == 0 { - tunFd, err := open(options.Name) + tunFd, err := open(options.Name, options.GSO) if err != nil { return nil, err } @@ -62,18 +65,60 @@ func New(options Options) (Tun, error) { } } +func (t *NativeTun) FrontHeadroom() int { + if t.gsoEnabled { + return virtioNetHdrLen + } + return 0 +} + +func (t *NativeTun) UpstreamWriter() io.Writer { + return t.tunFile +} + +func (t *NativeTun) WriterReplaceable() bool { + return !t.gsoEnabled +} + func (t *NativeTun) Read(p []byte) (n int, err error) { - return t.tunFile.Read(p) + if t.gsoEnabled { + n, err = t.tunFile.Read(t.gsoBuffer) + if err != nil { + return + } + var sizes [1]int + n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0) + if err != nil { + return + } + if n == 0 { + return + } + n = sizes[0] + return + } else { + return t.tunFile.Read(p) + } } func (t *NativeTun) Write(p []byte) (n int, err error) { + if t.gsoEnabled { + defer func() { + t.tcp4GROTable.reset() + t.tcp6GROTable.reset() + }() + var toWrite []int + err = handleGRO([][]byte{p}, virtioNetHdrLen, t.tcp4GROTable, t.tcp6GROTable, &toWrite) + if err != nil { + return + } + if len(toWrite) == 0 { + return + } + } return t.tunFile.Write(p) } -func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { - return bufio.NewVectorisedWriter(t.tunFile) -} - var controlPath string func init() { @@ -86,7 +131,7 @@ func init() { } } -func open(name string) (int, error) { +func open(name string, vnetHdr bool) (int, error) { fd, err := unix.Open(controlPath, unix.O_RDWR, 0) if err != nil { return -1, err @@ -100,6 +145,9 @@ func open(name string) (int, error) { copy(ifr.name[:], name) ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI + if vnetHdr { + ifr.flags |= unix.IFF_VNET_HDR + } _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, 
uintptr(unsafe.Pointer(&ifr))) if errno != 0 { unix.Close(fd) @@ -142,6 +190,28 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { } } + if t.options.GSO { + vnethdrEnabled, err := checkVNETHDREnabled(uint16(t.tunFd), t.options.Name) + if err != nil { + return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled") + } + if !vnethdrEnabled { + return E.New("enable offload: IFF_VNET_HDR not enabled") + } + const ( + // TODO: support TSO with ECN bits + tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + ) + err = unix.IoctlSetInt(t.tunFd, unix.TUNSETOFFLOAD, tunOffloads) + if err != nil { + return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload") + } + t.gsoEnabled = true + t.gsoBuffer = make([]byte, virtioNetHdrLen+int(t.options.GSOMaxSize)) + t.tcp4GROTable = newTCPGROTable() + t.tcp6GROTable = newTCPGROTable() + } + err = netlink.LinkSetUp(tunLink) if err != nil { return err @@ -181,6 +251,18 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { return nil } +func checkVNETHDREnabled(fd uint16, name string) (bool, error) { + ifr, err := unix.NewIfreq(name) + if err != nil { + return false, err + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return false, os.NewSyscallError("TUNGETIFF", err) + } + return ifr.Uint16()&unix.IFF_VNET_HDR != 0, nil +} + func (t *NativeTun) Close() error { if t.interfaceCallback != nil { t.options.InterfaceMonitor.UnregisterCallback(t.interfaceCallback) diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index b5c400c..65aa2ec 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -10,6 +10,13 @@ import ( var _ GVisorTun = (*NativeTun)(nil) func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { + if t.gsoEnabled { + return fdbased.New(&fdbased.Options{ + FDs: []int{t.tunFd}, + MTU: t.options.MTU, + GSOMaxSize: t.options.GSOMaxSize, + }) + } return fdbased.New(&fdbased.Options{ FDs: []int{t.tunFd}, MTU: t.options.MTU, diff
--git a/tun_linux_offload.go b/tun_linux_offload.go new file mode 100644 index 0000000..272fc79 --- /dev/null +++ b/tun_linux_offload.go @@ -0,0 +1,764 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "unsafe" + + "github.com/sagernet/sing-tun/internal/clashtcpip" + + "golang.org/x/sys/unix" +) + +const ( + tcpFlagsOffset = 13 + idealBatchSize = 1 +) + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +var errTooManySegments = errors.New("too many segments") + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// flowKey represents the key for a flow. +type flowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. +} + +// tcpGROTable holds flow and coalescing information for the purposes of GRO. 
+type tcpGROTable struct { + itemsByFlow map[flowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize), + itemsPool: make([][]tcpGROItem, idealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize) + } + return t +} + +func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey { + key := flowKey{} + addrSize := dstAddr - srcAddr + copy(key.srcAddr[:], pkt[srcAddr:dstAddr]) + copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. 
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key flowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key flowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. 
type canCoalesce int

const (
	// coalescePrepend: pkt ends exactly where the item's sequence range
	// begins, so pkt would be placed in front of the item.
	coalescePrepend canCoalesce = -1
	// coalesceUnavailable: the two packets must not be merged.
	coalesceUnavailable canCoalesce = 0
	// coalesceAppend: pkt begins exactly where the item's sequence range
	// ends, so pkt would be placed behind the item.
	coalesceAppend canCoalesce = 1
)

// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
//
// pkt is the candidate packet starting at its IP header; iphLen, tcphLen, seq,
// pshSet and gsoSize are its IP header length, TCP header length, sequence
// number, PSH flag and payload length. The packet tracked by item is read from
// bufs[item.bufsIndex][bufsOffset:].
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
	pktTarget := bufs[item.bufsIndex][bufsOffset:]
	if tcphLen != item.tcphLen {
		// cannot coalesce with unequal tcp options len
		return coalesceUnavailable
	}
	if tcphLen > 20 {
		// NOTE(review): the target slice starts at item.iphLen but ends at
		// iphLen; this is only consistent when both packets have the same IP
		// header length - presumably guaranteed by the per-family tables,
		// confirm.
		if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
			// cannot coalesce with unequal tcp options
			return coalesceUnavailable
		}
	}
	// The high nibble of the first byte is the IP version.
	if pkt[0]>>4 == 6 {
		if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
			// cannot coalesce with unequal Traffic class values
			return coalesceUnavailable
		}
		if pkt[7] != pktTarget[7] {
			// cannot coalesce with unequal Hop limit values
			return coalesceUnavailable
		}
	} else {
		if pkt[1] != pktTarget[1] {
			// cannot coalesce with unequal ToS values
			return coalesceUnavailable
		}
		if pkt[6]>>5 != pktTarget[6]>>5 {
			// cannot coalesce with unequal DF or reserved bits. MF is checked
			// further up the stack.
			return coalesceUnavailable
		}
		if pkt[8] != pktTarget[8] {
			// cannot coalesce with unequal TTL values
			return coalesceUnavailable
		}
	}
	// seq adjacency
	// lhsLen is the payload length attributed to item so far: its own segment
	// plus numMerged further segments of gsoSize each.
	lhsLen := item.gsoSize
	lhsLen += item.numMerged * item.gsoSize
	if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
		if item.pshSet {
			// We cannot append to a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
			// A smaller than gsoSize packet has been appended previously.
			// Nothing can come after a smaller packet on the end.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		return coalesceAppend
	} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
		if pshSet {
			// We cannot prepend with a segment that has the PSH flag set, PSH
			// can only be set on the final segment in a reassembled group.
			return coalesceUnavailable
		}
		if gsoSize < item.gsoSize {
			// We cannot have a larger packet following a smaller one.
			return coalesceUnavailable
		}
		if gsoSize > item.gsoSize && item.numMerged > 0 {
			// There's at least one previous merge, and we're larger than all
			// previous. This would put multiple smaller packets on the end.
			return coalesceUnavailable
		}
		return coalescePrepend
	}
	return coalesceUnavailable
}

// tcpChecksumValid reports whether the TCP checksum carried in pkt verifies.
// pkt starts at the IP header and iphLen is that header's length. The
// pseudo-header sum is built from the addresses at the IPv4 or IPv6 source
// address offset (selected by isV6) and the TCP length len(pkt)-iphLen, then
// folded together with the whole TCP segment; the packet is valid when the
// complement of that result is zero.
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
	srcAddrAt := ipv4SrcAddrOffset
	addrSize := 4
	if isV6 {
		srcAddrAt = ipv6SrcAddrOffset
		addrSize = 16
	}
	tcpTotalLen := uint16(len(pkt) - int(iphLen))
	tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
	return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0
}

// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int

const (
	// coalesceInsufficientCap: the buffer that would hold the merged packet
	// does not have enough capacity.
	coalesceInsufficientCap coalesceResult = iota
	// coalescePSHEnding: a packet with PSH set was offered for prepending.
	coalescePSHEnding
	// coalesceItemInvalidCSum: the tracked item's TCP checksum did not verify.
	coalesceItemInvalidCSum
	// coalescePktInvalidCSum: the candidate packet's TCP checksum did not
	// verify.
	coalescePktInvalidCSum
	// coalesceSuccess: the packets were merged.
	coalesceSuccess
)

// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome.
// This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
//
// mode selects prepend or append. pkt is the candidate packet (starting at
// its IP header) stored at bufs[pktBuffsIndex]; gsoSize, seq and pshSet are its
// payload length, sequence number and PSH flag. item is updated in place on
// success (sentSeq on prepend, pshSet on append, gsoSize, numMerged); the
// caller is responsible for storing it back into the table.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
	var pktHead []byte // the packet that will end up at the front
	headersLen := item.iphLen + item.tcphLen
	// Length of the merged packet: the tracked packet plus pkt's payload.
	coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)

	// Copy data
	if mode == coalescePrepend {
		pktHead = pkt
		// NOTE(review): tcpGRO passes pkt already sliced at the offset, so
		// subtracting bufsOffset again makes this test stricter than needed -
		// confirm this is intended.
		if cap(pkt)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if pshSet {
			return coalescePSHEnding
		}
		// The item's checksum is only verified before its first merge.
		if item.numMerged == 0 {
			if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !tcpChecksumValid(pkt, item.iphLen, isV6) {
			return coalescePktInvalidCSum
		}
		item.sentSeq = seq
		extendBy := coalescedLen - len(pktHead)
		bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
		// Append the item's payload (everything after its headers) behind pkt.
		copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
		// Flip the slice headers in bufs as part of prepend. The index of item
		// is already being tracked for writing.
		bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
	} else {
		pktHead = bufs[item.bufsIndex][bufsOffset:]
		if cap(pktHead)-bufsOffset < coalescedLen {
			// We don't want to allocate a new underlying array if capacity is
			// too small.
			return coalesceInsufficientCap
		}
		if item.numMerged == 0 {
			if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
				return coalesceItemInvalidCSum
			}
		}
		if !tcpChecksumValid(pkt, item.iphLen, isV6) {
			return coalescePktInvalidCSum
		}
		if pshSet {
			// We are appending a segment with PSH set.
			item.pshSet = pshSet
			pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
		}
		extendBy := len(pkt) - int(headersLen)
		bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
		// Append pkt's payload behind the tracked packet.
		copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
	}

	// The item keeps the largest payload size seen among its segments.
	if gsoSize > item.gsoSize {
		item.gsoSize = gsoSize
	}

	item.numMerged++
	return coalesceSuccess
}

const (
	// ipv4FlagMoreFragments is tested against byte 6 of the IPv4 header.
	ipv4FlagMoreFragments uint8 = 0x20
)

const (
	// Offsets of the source address within the IPv4 and IPv6 headers; the
	// destination address follows it immediately.
	ipv4SrcAddrOffset = 12
	ipv6SrcAddrOffset = 8
	maxUint16         = 1<<16 - 1
)

// tcpGROResult is the outcome of evaluating one packet with tcpGRO.
type tcpGROResult int

const (
	tcpGROResultNoop tcpGROResult = iota
	tcpGROResultTableInsert
	tcpGROResultCoalesced
)

// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
//
// offset is where the IP packet starts inside each element of bufs. The
// header-field reads below index pkt without length checks of their own; the
// caller (handleGRO) only passes packets accepted by isTCP4NoIPOptions or
// isTCP6NoEH.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
	pkt := bufs[pktI][offset:]
	if len(pkt) > maxUint16 {
		// A valid IPv4 or IPv6 packet will never exceed this.
		return tcpGROResultNoop
	}
	// IPv4: header length is the low nibble of byte 0, in 32-bit words.
	iphLen := int((pkt[0] & 0x0F) * 4)
	if isV6 {
		iphLen = 40
		ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
		if ipv6HPayloadLen != len(pkt)-iphLen {
			return tcpGROResultNoop
		}
	} else {
		totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
		if totalLen != len(pkt) {
			return tcpGROResultNoop
		}
	}
	if len(pkt) < iphLen {
		return tcpGROResultNoop
	}
	// TCP header length is the high nibble of byte 12, in 32-bit words.
	tcphLen := int((pkt[iphLen+12] >> 4) * 4)
	if tcphLen < 20 || tcphLen > 60 {
		return tcpGROResultNoop
	}
	if len(pkt) < iphLen+tcphLen {
		return tcpGROResultNoop
	}
	if !isV6 {
		if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
			// no GRO support for fragmented segments for now
			return tcpGROResultNoop
		}
	}
	tcpFlags := pkt[iphLen+tcpFlagsOffset]
	var pshSet bool
	// not a candidate if any non-ACK flags (except PSH+ACK) are set
	if tcpFlags != tcpFlagACK {
		if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
			return tcpGROResultNoop
		}
		pshSet = true
	}
	gsoSize := uint16(len(pkt) - tcphLen - iphLen)
	// not a candidate if payload len is 0
	if gsoSize < 1 {
		return tcpGROResultNoop
	}
	seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
	srcAddrOffset := ipv4SrcAddrOffset
	addrLen := 4
	if isV6 {
		srcAddrOffset = ipv6SrcAddrOffset
		addrLen = 16
	}
	items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	if !existing {
		// NOTE(review): lookupOrInsert has just inserted the packet, yet Noop
		// rather than TableInsert is returned. handleGRO queues the buffer
		// for writing in both cases, so this looks harmless - confirm intent.
		return tcpGROResultNoop
	}
	for i := len(items) - 1; i >= 0; i-- {
		// In the best case of packets arriving in order iterating in reverse is
		// more efficient if there are multiple items for a given flow. This
		// also enables a natural table.deleteAt() in the
		// coalesceItemInvalidCSum case without the need for index tracking.
		// This algorithm makes a best effort to coalesce in the event of
		// unordered packets, where pkt may land anywhere in items from a
		// sequence number perspective, however once an item is inserted into
		// the table it is never compared across other items later.
		item := items[i]
		can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
		if can != coalesceUnavailable {
			// item is a copy; coalesceTCPPackets mutates it and updateAt
			// writes it back on success.
			result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
			switch result {
			case coalesceSuccess:
				table.updateAt(item, i)
				return tcpGROResultCoalesced
			case coalesceItemInvalidCSum:
				// delete the item with an invalid csum
				table.deleteAt(item.key, i)
			case coalescePktInvalidCSum:
				// no point in inserting an item that we can't coalesce
				return tcpGROResultNoop
			default:
			}
		}
	}
	// failed to coalesce with any other packets; store the item in the flow
	table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
	return tcpGROResultTableInsert
}

// isTCP4NoIPOptions reports whether b is at least 40 bytes long, has IP
// version 4, a header length of 5 words (no IP options) and carries TCP.
func isTCP4NoIPOptions(b []byte) bool {
	if len(b) < 40 {
		return false
	}
	if b[0]>>4 != 4 {
		return false
	}
	if b[0]&0x0F != 5 {
		return false
	}
	if b[9] != unix.IPPROTO_TCP {
		return false
	}
	return true
}

// isTCP6NoEH reports whether b is at least 60 bytes long, has IP version 6
// and a next-header field of TCP (i.e. no extension headers).
func isTCP6NoEH(b []byte) bool {
	if len(b) < 60 {
		return false
	}
	if b[0]>>4 != 6 {
		return false
	}
	if b[6] != unix.IPPROTO_TCP {
		return false
	}
	return true
}

// applyCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
+func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. 
+ addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. +func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result tcpGROResult + switch { + case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce + result = tcpGRO(bufs, offset, i, tcp4Table, false) + case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce + result = tcpGRO(bufs, offset, i, tcp6Table, true) + } + switch result { + case tcpGROResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case tcpGROResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) + err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) + return errors.Join(err4, err6) +} + +// tcpTSO splits packets from in into outBuffs, writing the size of each +// element into 
sizes. It returns the number of buffers populated, and/or an +// error. +func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + tcpCSumAt := int(hdr.csumStart + hdr.csumOffset) + in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum + firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, errTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksumFold(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. 
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // TCP header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // TCP checksum + tcpHLen := int(hdr.hdrLen - hdr.csumStart) + tcpLenForPseudo := uint16(tcpHLen + segmentDataLen) + tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo) + tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. + initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial))) + return nil +} + +// handleVirtioRead splits in into bufs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of bufs, +// and returns the number of packets read. 
//
// in must begin with a virtio header. A packet with GSO type NONE is copied
// into bufs[0] as a single packet (after completing its checksum when the
// header asks for it); TCPv4/TCPv6 GSO packets are validated and handed to
// tcpTSO for segmentation. Any other GSO type is rejected with an error.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
	var hdr virtioNetHdr
	err := hdr.decode(in)
	if err != nil {
		return 0, err
	}
	// From here on in refers to the packet itself.
	in = in[virtioNetHdrLen:]
	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
		if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
			// This means CHECKSUM_PARTIAL in skb context. We are responsible
			// for computing the checksum starting at hdr.csumStart and placing
			// at hdr.csumOffset.
			err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
			if err != nil {
				return 0, err
			}
		}
		if len(in) > len(bufs[0][offset:]) {
			return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
		}
		n := copy(bufs[0][offset:], in)
		sizes[0] = n
		return 1, nil
	}
	if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
		return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
	}

	// NOTE(review): in[0] is read without checking that any bytes follow the
	// virtio header; decode only guarantees virtioNetHdrLen bytes. Confirm the
	// reader never delivers a bare header with a TCP GSO type.
	ipVersion := in[0] >> 4
	// The IP version in the packet must agree with the GSO type.
	switch ipVersion {
	case 4:
		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
	case 6:
		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
	default:
		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
	}

	if len(in) <= int(hdr.csumStart+12) {
		return 0, errors.New("packet is too short")
	}
	// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the TCP header length and add it onto
	// csumStart, which is synonymous for IP header length.
	tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
	if tcpHLen < 20 || tcpHLen > 60 {
		// A TCP header must be between 20 and 60 bytes in length.
		return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
	}
	hdr.hdrLen = hdr.csumStart + tcpHLen

	if len(in) < int(hdr.hdrLen) {
		return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
	}

	// hdrLen was just computed as csumStart+tcpHLen in uint16, so this can
	// only trigger if that addition wrapped around.
	if hdr.hdrLen < hdr.csumStart {
		return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
	}
	cSumAt := int(hdr.csumStart + hdr.csumOffset)
	if cSumAt+1 >= len(in) {
		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
	}

	return tcpTSO(in, hdr, bufs, sizes, offset)
}

// checksumNoFold adds clashtcpip.Sum(b) to initial and returns the result
// without folding it.
func checksumNoFold(b []byte, initial uint64) uint64 {
	return initial + uint64(clashtcpip.Sum(b))
}

// checksumFold passes initial (truncated to 32 bits) and b to
// clashtcpip.Checksum and returns the resulting bytes as a big-endian uint16.
// Callers in this file take the bitwise complement of the result before
// storing or validating it.
func checksumFold(b []byte, initial uint64) uint16 {
	r := clashtcpip.Checksum(uint32(initial), b)
	return binary.BigEndian.Uint16(r[:])
}

// pseudoHeaderChecksumNoFold returns the unfolded sum over the pseudo header
// fields: source address, destination address, a zero byte followed by
// protocol, and totalLen in big-endian order.
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
	sum := checksumNoFold(srcAddr, 0)
	sum = checksumNoFold(dstAddr, sum)
	sum = checksumNoFold([]byte{0, protocol}, sum)
	tmp := make([]byte, 2)
	binary.BigEndian.PutUint16(tmp, totalLen)
	return checksumNoFold(tmp, sum)
}
diff --git a/tun_nondarwin.go b/tun_nondarwin.go deleted file mode 100644 index 0faa2c9..0000000 --- a/tun_nondarwin.go +++ /dev/null @@ -1,5 +0,0 @@ -//go:build !darwin - -package tun - -const PacketOffset = 0 diff --git a/tun_nonlinux.go b/tun_nonlinux.go new file mode 100644 index 0000000..28ce640 --- /dev/null +++ b/tun_nonlinux.go @@ -0,0 +1,5 @@ +//go:build !linux + +package tun + +const OffloadOffset = 0 diff --git a/tun_windows.go b/tun_windows.go index 2028746..90a9867 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -19,7 +19,6 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/windnsapi" "golang.org/x/sys/windows" @@
-66,6 +65,10 @@ func New(options Options) (WinTun, error) { return nativeTun, nil } +func (t *NativeTun) FrontHeadroom() int { + return 0 +} + func (t *NativeTun) configure() error { luid := winipcfg.LUID(t.adapter.LUID()) if len(t.options.Inet4Address) > 0 { @@ -454,10 +457,6 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) { return 0, fmt.Errorf("write failed: %w", err) } -func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { - return t -} - func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { defer buf.ReleaseMulti(buffers) return common.Error(t.write(buf.ToSliceMulti(buffers)))