diff --git a/docs/config.yaml b/docs/config.yaml index 4e1b3a1883..0cec9a0484 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -905,9 +905,9 @@ listeners: listen: 0.0.0.0 # rule: sub-rule-name1 # 默认使用 rules,如果未找到 sub-rule 则直接使用 rules # proxy: proxy # 如果不为空则直接将该入站流量交由指定proxy处理(当proxy不为空时,这里的proxy名称必须合法,否则会出错) - # token: # tuicV4填写(不可同时填写users) + # token: # tuicV4填写(可以同时填写users) # - TOKEN - # users: # tuicV5填写(不可同时填写token) + # users: # tuicV5填写(可以同时填写token) # 00000000-0000-0000-0000-000000000000: PASSWORD_0 # 00000000-0000-0000-0000-000000000001: PASSWORD_1 # certificate: ./server.crt @@ -978,9 +978,9 @@ listeners: # tuic-server: # enable: true # listen: 127.0.0.1:10443 -# token: # tuicV4填写(不可同时填写users) +# token: # tuicV4填写(可以同时填写users) # - TOKEN -# users: # tuicV5填写(不可同时填写token) +# users: # tuicV5填写(可以同时填写token) # 00000000-0000-0000-0000-000000000000: PASSWORD_0 # 00000000-0000-0000-0000-000000000001: PASSWORD_1 # certificate: ./server.crt diff --git a/listener/tuic/server.go b/listener/tuic/server.go index 742d8ac9b7..76996b272c 100644 --- a/listener/tuic/server.go +++ b/listener/tuic/server.go @@ -26,7 +26,7 @@ type Listener struct { closed bool config LC.TuicServer udpListeners []net.PacketConn - servers []tuic.Server + servers []*tuic.Server } func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.PacketAdapter, additions ...inbound.Addition) (*Listener, error) { @@ -102,42 +102,29 @@ func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.Packet return nil } - var optionV4 *tuic.ServerOptionV4 - var optionV5 *tuic.ServerOptionV5 + option := &tuic.ServerOption{ + HandleTcpFn: handleTcpFn, + HandleUdpFn: handleUdpFn, + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + CongestionController: config.CongestionController, + AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond, + MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize, + CWND: config.CWND, + } if len(config.Token) > 0 { tokens := make([][32]byte, len(config.Token)) for i, token := range config.Token { tokens[i] = tuic.GenTKN(token) } - - optionV4 = &tuic.ServerOptionV4{ - HandleTcpFn: handleTcpFn, - HandleUdpFn: handleUdpFn, - TlsConfig: tlsConfig, - QuicConfig: quicConfig, - Tokens: tokens, - CongestionController: config.CongestionController, - AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond, - MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize, - CWND: config.CWND, - } - } else { + option.Tokens = tokens + } + if len(config.Users) > 0 { users := make(map[[16]byte]string) for _uuid, password := range config.Users { users[uuid.FromStringOrNil(_uuid)] = password } - - optionV5 = &tuic.ServerOptionV5{ - HandleTcpFn: handleTcpFn, - HandleUdpFn: handleUdpFn, - TlsConfig: tlsConfig, - QuicConfig: quicConfig, - Users: users, - CongestionController: config.CongestionController, - AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond, - MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize, - CWND: config.CWND, - } + option.Users = users } sl := &Listener{false, config, nil, nil} @@ -157,12 +144,8 @@ func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.Packet sl.udpListeners = append(sl.udpListeners, ul) - var server tuic.Server - if optionV4 != nil { - server, err = tuic.NewServerV4(optionV4, ul) - } else { - server, err = tuic.NewServerV5(optionV5, ul) - } + var server *tuic.Server + server, err = tuic.NewServer(option, ul) if err != nil { return nil, err } diff --git a/transport/tuic/common/type.go b/transport/tuic/common/type.go index a5a60986e4..9a568dd731 100644 --- a/transport/tuic/common/type.go +++ b/transport/tuic/common/type.go @@ -1,11 +1,13 @@ package common import ( + "bufio" "context" "errors" "net" "time" + N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" "github.com/metacubex/quic-go" @@ -28,9 +30,12 @@ type Client interface { Close() } -type Server interface { - Serve() error - Close() error +type ServerHandler interface { + AuthOk() bool + HandleTimeout() + HandleStream(conn *N.BufferedConn) (err error) + HandleMessage(message []byte) (err error) + HandleUniStream(reader *bufio.Reader) (err error) } type UdpRelayMode uint8 diff --git a/transport/tuic/server.go b/transport/tuic/server.go new file mode 100644 index 0000000000..47850107ff --- /dev/null +++ b/transport/tuic/server.go @@ -0,0 +1,234 @@ +package tuic + +import ( + "bufio" + "context" + "crypto/tls" + "net" + "time" + + "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/common/utils" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" + "github.com/Dreamacro/clash/transport/tuic/common" + v4 "github.com/Dreamacro/clash/transport/tuic/v4" + v5 "github.com/Dreamacro/clash/transport/tuic/v5" + + "github.com/gofrs/uuid/v5" + "github.com/metacubex/quic-go" +) + +type ServerOption struct { + HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error + HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error + + TlsConfig *tls.Config + QuicConfig *quic.Config + Tokens [][32]byte // V4 special + Users map[[16]byte]string // V5 special + CongestionController string + AuthenticationTimeout time.Duration + MaxUdpRelayPacketSize int + CWND int +} + +type Server struct { + *ServerOption + optionV4 *v4.ServerOption + optionV5 *v5.ServerOption + listener *quic.EarlyListener +} + +func (s *Server) Serve() error { + for { + conn, err := s.listener.Accept(context.Background()) + if err != nil { + return err + } + common.SetCongestionController(conn, s.CongestionController, s.CWND) + h := &serverHandler{ + Server: s, + quicConn: conn, + uuid: utils.NewUUIDV4(), + } + if h.optionV4 != nil { + h.v4Handler = v4.NewServerHandler(h.optionV4, conn, h.uuid) + } + if h.optionV5 != nil { + h.v5Handler = v5.NewServerHandler(h.optionV5, conn, h.uuid) + } + go h.handle() + } +} + +func (s *Server) Close() error { + return s.listener.Close() +} + +type serverHandler struct { + *Server + quicConn quic.EarlyConnection + uuid uuid.UUID + + v4Handler common.ServerHandler + v5Handler common.ServerHandler +} + +func (s *serverHandler) handle() { + go func() { + _ = s.handleUniStream() + }() + go func() { + _ = s.handleStream() + }() + go func() { + _ = s.handleMessage() + }() + + <-s.quicConn.HandshakeComplete() + time.AfterFunc(s.AuthenticationTimeout, func() { + if s.v4Handler != nil { + if s.v4Handler.AuthOk() { + return + } + } + + if s.v5Handler != nil { + if s.v5Handler.AuthOk() { + return + } + } + + if s.v4Handler != nil { + s.v4Handler.HandleTimeout() + } + + if s.v5Handler != nil { + s.v5Handler.HandleTimeout() + } + }) +} + +func (s *serverHandler) handleMessage() (err error) { + for { + var message []byte + message, err = s.quicConn.ReceiveMessage() + if err != nil { + return err + } + go func() (err error) { + if len(message) > 0 { + switch message[0] { + case v4.VER: + if s.v4Handler != nil { + return s.v4Handler.HandleMessage(message) + } + case v5.VER: + if s.v5Handler != nil { + return s.v5Handler.HandleMessage(message) + } + } + } + return + }() + } +} + +func (s *serverHandler) handleStream() (err error) { + for { + var quicStream quic.Stream + quicStream, err = s.quicConn.AcceptStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + stream := common.NewQuicStreamConn( + quicStream, + s.quicConn.LocalAddr(), + s.quicConn.RemoteAddr(), + nil, + ) + conn := N.NewBufferedConn(stream) + + verBytes, err := conn.Peek(1) + if err != nil { + _ = conn.Close() + return err + } + + switch verBytes[0] { + case v4.VER: + if s.v4Handler != nil { + return s.v4Handler.HandleStream(conn) + } + case v5.VER: + if s.v5Handler != nil { + return s.v5Handler.HandleStream(conn) + } + } + return + }() + } +} + +func (s *serverHandler) handleUniStream() (err error) { + for { + var stream quic.ReceiveStream + stream, err = s.quicConn.AcceptUniStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + defer func() { + stream.CancelRead(0) + }() + reader := bufio.NewReader(stream) + verBytes, err := reader.Peek(1) + if err != nil { + return err + } + + switch verBytes[0] { + case v4.VER: + if s.v4Handler != nil { + return s.v4Handler.HandleUniStream(reader) + } + case v5.VER: + if s.v5Handler != nil { + return s.v5Handler.HandleUniStream(reader) + } + } + return + }() + } +} + +func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) { + listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig) + if err != nil { + return nil, err + } + server := &Server{ + ServerOption: option, + listener: listener, + } + if len(option.Tokens) > 0 { + server.optionV4 = &v4.ServerOption{ + HandleTcpFn: option.HandleTcpFn, + HandleUdpFn: option.HandleUdpFn, + Tokens: option.Tokens, + MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, + } + } + if len(option.Users) > 0 { + server.optionV5 = &v5.ServerOption{ + HandleTcpFn: option.HandleTcpFn, + HandleUdpFn: option.HandleUdpFn, + Users: option.Users, + MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, + } + } + return server, nil +} diff --git a/transport/tuic/tuic.go b/transport/tuic/tuic.go index 7be6f45056..8832ef91d1 100644 --- a/transport/tuic/tuic.go +++ b/transport/tuic/tuic.go @@ -1,8 +1,6 @@ package tuic import ( - "net" - C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/tuic/common" v4 "github.com/Dreamacro/clash/transport/tuic/v4" @@ -26,19 +24,6 @@ type DialFunc = common.DialFunc var TooManyOpenStreams = common.TooManyOpenStreams -type ServerOptionV4 = v4.ServerOption -type ServerOptionV5 = v5.ServerOption - -type Server = common.Server - -func NewServerV4(option *ServerOptionV4, pc net.PacketConn) (Server, error) { - return v4.NewServer(option, pc) -} - -func NewServerV5(option *ServerOptionV5, pc net.PacketConn) (Server, error) { - return v5.NewServer(option, pc) -} - const DefaultStreamReceiveWindow = common.DefaultStreamReceiveWindow const DefaultConnectionReceiveWindow = common.DefaultConnectionReceiveWindow diff --git a/transport/tuic/v4/packet.go b/transport/tuic/v4/packet.go index 2f808befa2..2066ceb732 100644 --- a/transport/tuic/v4/packet.go +++ b/transport/tuic/v4/packet.go @@ -3,9 +3,9 @@ package v4 import ( "net" "sync" - "sync/atomic" "time" + "github.com/Dreamacro/clash/common/atomic" N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/tuic/common" diff --git a/transport/tuic/v4/protocol.go b/transport/tuic/v4/protocol.go index 65f0f9d5a6..11ac3b4eb9 100644 --- a/transport/tuic/v4/protocol.go +++ b/transport/tuic/v4/protocol.go @@ -36,6 +36,8 @@ const ( ResponseType = CommandType(0xff) ) +const VER byte = 0x04 + func (c CommandType) String() string { switch c { case AuthenticateType: @@ -66,7 +68,7 @@ type CommandHead struct { func NewCommandHead(TYPE CommandType) CommandHead { return CommandHead{ - VER: 0x04, + VER: VER, TYPE: TYPE, } } diff --git a/transport/tuic/v4/server.go b/transport/tuic/v4/server.go index 37b311b027..9513ccfd51 100644 --- a/transport/tuic/v4/server.go +++ b/transport/tuic/v4/server.go @@ -3,18 +3,14 @@ package v4 import ( "bufio" "bytes" - "context" - "crypto/tls" "fmt" "net" "sync" - "sync/atomic" - "time" "github.com/Dreamacro/clash/adapter/inbound" + "github.com/Dreamacro/clash/common/atomic" N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" - "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/tuic/common" @@ -27,106 +23,55 @@ type ServerOption struct { HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error - TlsConfig *tls.Config - QuicConfig *quic.Config Tokens [][32]byte - CongestionController string - AuthenticationTimeout time.Duration MaxUdpRelayPacketSize int - CWND int } -type Server struct { - *ServerOption - listener *quic.EarlyListener -} - -func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) { - listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig) - if err != nil { - return nil, err - } - return &Server{ +func NewServerHandler(option *ServerOption, quicConn quic.EarlyConnection, uuid uuid.UUID) common.ServerHandler { + return &serverHandler{ ServerOption: option, - listener: listener, - }, err -} - -func (s *Server) Serve() error { - for { - conn, err := s.listener.Accept(context.Background()) - if err != nil { - return err - } - common.SetCongestionController(conn, s.CongestionController, s.CWND) - h := &serverHandler{ - Server: s, - quicConn: conn, - uuid: utils.NewUUIDV4(), - authCh: make(chan struct{}), - } - go h.handle() + quicConn: quicConn, + uuid: uuid, + authCh: make(chan struct{}), } } -func (s *Server) Close() error { - return s.listener.Close() -} - type serverHandler struct { - *Server + *ServerOption quicConn quic.EarlyConnection uuid uuid.UUID authCh chan struct{} - authOk bool + authOk atomic.Bool authOnce sync.Once udpInputMap sync.Map } -func (s *serverHandler) handle() { - go func() { - _ = s.handleUniStream() - }() - go func() { - _ = s.handleStream() - }() - go func() { - _ = s.handleMessage() - }() +func (s *serverHandler) AuthOk() bool { + return s.authOk.Load() +} - <-s.quicConn.HandshakeComplete() - time.AfterFunc(s.AuthenticationTimeout, func() { - s.authOnce.Do(func() { - _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") - s.authOk = false - close(s.authCh) - }) +func (s *serverHandler) HandleTimeout() { + s.authOnce.Do(func() { + _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") + s.authOk.Store(false) + close(s.authCh) }) } -func (s *serverHandler) handleMessage() (err error) { - for { - var message []byte - message, err = s.quicConn.ReceiveMessage() - if err != nil { - return err - } - go func() (err error) { - buffer := bytes.NewBuffer(message) - packet, err := ReadPacket(buffer) - if err != nil { - return - } - return s.parsePacket(packet, common.NATIVE) - }() +func (s *serverHandler) HandleMessage(message []byte) (err error) { + buffer := bytes.NewBuffer(message) + packet, err := ReadPacket(buffer) + if err != nil { + return } + return s.parsePacket(packet, common.NATIVE) } func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { <-s.authCh - if !s.authOk { + if !s.authOk.Load() { return } var assocId uint32 @@ -157,119 +102,90 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayM }) } -func (s *serverHandler) handleStream() (err error) { - for { - var quicStream quic.Stream - quicStream, err = s.quicConn.AcceptStream(context.Background()) - if err != nil { - return err - } - go func() (err error) { - stream := common.NewQuicStreamConn( - quicStream, - s.quicConn.LocalAddr(), - s.quicConn.RemoteAddr(), - nil, - ) - conn := N.NewBufferedConn(stream) - connect, err := ReadConnect(conn) - if err != nil { - return err - } - <-s.authCh - if !s.authOk { - return conn.Close() - } - - buf := pool.GetBuffer() - defer pool.PutBuffer(buf) - err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr()) - if err != nil { - err = NewResponseFailed().WriteTo(buf) - defer conn.Close() - } else { - err = NewResponseSucceed().WriteTo(buf) - } - if err != nil { - _ = conn.Close() - return err - } - _, err = buf.WriteTo(stream) - if err != nil { - _ = conn.Close() - return err - } +func (s *serverHandler) HandleStream(conn *N.BufferedConn) (err error) { + connect, err := ReadConnect(conn) + if err != nil { + return err + } + <-s.authCh + if !s.authOk.Load() { + return conn.Close() + } - return - }() + buf := pool.GetBuffer() + defer pool.PutBuffer(buf) + err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr()) + if err != nil { + err = NewResponseFailed().WriteTo(buf) + defer conn.Close() + } else { + err = NewResponseSucceed().WriteTo(buf) } + if err != nil { + _ = conn.Close() + return err + } + _, err = buf.WriteTo(conn) + if err != nil { + _ = conn.Close() + return err + } + + return } -func (s *serverHandler) handleUniStream() (err error) { - for { - var stream quic.ReceiveStream - stream, err = s.quicConn.AcceptUniStream(context.Background()) +func (s *serverHandler) HandleUniStream(reader *bufio.Reader) (err error) { + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case AuthenticateType: + var authenticate Authenticate + authenticate, err = ReadAuthenticateWithHead(commandHead, reader) if err != nil { - return err + return } - go func() (err error) { - defer func() { - stream.CancelRead(0) - }() - reader := bufio.NewReader(stream) - commandHead, err := ReadCommandHead(reader) - if err != nil { - return + authOk := false + for _, tkn := range s.Tokens { + if authenticate.TKN == tkn { + authOk = true + break } - switch commandHead.TYPE { - case AuthenticateType: - var authenticate Authenticate - authenticate, err = ReadAuthenticateWithHead(commandHead, reader) - if err != nil { - return - } - authOk := false - for _, tkn := range s.Tokens { - if authenticate.TKN == tkn { - authOk = true - break - } - } - s.authOnce.Do(func() { - if !authOk { - _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") - } - s.authOk = authOk - close(s.authCh) - }) - case PacketType: - var packet Packet - packet, err = ReadPacketWithHead(commandHead, reader) - if err != nil { - return - } - return s.parsePacket(packet, common.QUIC) - case DissociateType: - var disassociate Dissociate - disassociate, err = ReadDissociateWithHead(commandHead, reader) - if err != nil { - return - } - if v, loaded := s.udpInputMap.LoadAndDelete(disassociate.ASSOC_ID); loaded { - writeClosed := v.(*atomic.Bool) - writeClosed.Store(true) - } - case HeartbeatType: - var heartbeat Heartbeat - heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) - if err != nil { - return - } - heartbeat.BytesLen() + } + s.authOnce.Do(func() { + if !authOk { + _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") } + s.authOk.Store(authOk) + close(s.authCh) + }) + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) + if err != nil { + return + } + return s.parsePacket(packet, common.QUIC) + case DissociateType: + var disassociate Dissociate + disassociate, err = ReadDissociateWithHead(commandHead, reader) + if err != nil { return - }() + } + if v, loaded := s.udpInputMap.LoadAndDelete(disassociate.ASSOC_ID); loaded { + writeClosed := v.(*atomic.Bool) + writeClosed.Store(true) + } + case HeartbeatType: + var heartbeat Heartbeat + heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) + if err != nil { + return + } + heartbeat.BytesLen() } + return } type serverUDPPacket struct { diff --git a/transport/tuic/v5/packet.go b/transport/tuic/v5/packet.go index 9f546400eb..4a11d67182 100644 --- a/transport/tuic/v5/packet.go +++ b/transport/tuic/v5/packet.go @@ -4,9 +4,9 @@ import ( "errors" "net" "sync" - "sync/atomic" "time" + "github.com/Dreamacro/clash/common/atomic" N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/tuic/common" diff --git a/transport/tuic/v5/protocol.go b/transport/tuic/v5/protocol.go index dc7062ea15..83b4414649 100644 --- a/transport/tuic/v5/protocol.go +++ b/transport/tuic/v5/protocol.go @@ -35,6 +35,8 @@ const ( HeartbeatType = CommandType(0x04) ) +const VER byte = 0x05 + func (c CommandType) String() string { switch c { case AuthenticateType: @@ -63,7 +65,7 @@ type CommandHead struct { func NewCommandHead(TYPE CommandType) CommandHead { return CommandHead{ - VER: 0x05, + VER: VER, TYPE: TYPE, } } diff --git a/transport/tuic/v5/server.go b/transport/tuic/v5/server.go index 7b21ee6c59..96b3d24fdc 100644 --- a/transport/tuic/v5/server.go +++ b/transport/tuic/v5/server.go @@ -3,18 +3,13 @@ package v5 import ( "bufio" "bytes" - "context" - "crypto/tls" "fmt" - "net" "sync" - "sync/atomic" - "time" "github.com/Dreamacro/clash/adapter/inbound" + "github.com/Dreamacro/clash/common/atomic" N "github.com/Dreamacro/clash/common/net" - "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/tuic/common" @@ -27,123 +22,72 @@ type ServerOption struct { HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error - TlsConfig *tls.Config - QuicConfig *quic.Config Users map[[16]byte]string - CongestionController string - AuthenticationTimeout time.Duration MaxUdpRelayPacketSize int - CWND int -} - -type Server struct { - *ServerOption - listener *quic.EarlyListener } -func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) { - listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig) - if err != nil { - return nil, err - } - return &Server{ +func NewServerHandler(option *ServerOption, quicConn quic.EarlyConnection, uuid uuid.UUID) common.ServerHandler { + return &serverHandler{ ServerOption: option, - listener: listener, - }, err -} - -func (s *Server) Serve() error { - for { - conn, err := s.listener.Accept(context.Background()) - if err != nil { - return err - } - common.SetCongestionController(conn, s.CongestionController, s.CWND) - h := &serverHandler{ - Server: s, - quicConn: conn, - uuid: utils.NewUUIDV4(), - authCh: make(chan struct{}), - } - go h.handle() + quicConn: quicConn, + uuid: uuid, + authCh: make(chan struct{}), } } -func (s *Server) Close() error { - return s.listener.Close() -} - type serverHandler struct { - *Server + *ServerOption quicConn quic.EarlyConnection uuid uuid.UUID authCh chan struct{} - authOk bool - authUUID string + authOk atomic.Bool + authUUID atomic.TypedValue[string] authOnce sync.Once udpInputMap sync.Map } -func (s *serverHandler) handle() { - go func() { - _ = s.handleUniStream() - }() - go func() { - _ = s.handleStream() - }() - go func() { - _ = s.handleMessage() - }() +func (s *serverHandler) AuthOk() bool { + return s.authOk.Load() +} - <-s.quicConn.HandshakeComplete() - time.AfterFunc(s.AuthenticationTimeout, func() { - s.authOnce.Do(func() { - _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") - s.authOk = false - close(s.authCh) - }) +func (s *serverHandler) HandleTimeout() { + s.authOnce.Do(func() { + _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") + s.authOk.Store(false) + close(s.authCh) }) } -func (s *serverHandler) handleMessage() (err error) { - for { - var message []byte - message, err = s.quicConn.ReceiveMessage() +func (s *serverHandler) HandleMessage(message []byte) (err error) { + reader := bytes.NewBuffer(message) + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) if err != nil { - return err + return } - go func() (err error) { - reader := bytes.NewBuffer(message) - commandHead, err := ReadCommandHead(reader) - if err != nil { - return - } - switch commandHead.TYPE { - case PacketType: - var packet Packet - packet, err = ReadPacketWithHead(commandHead, reader) - if err != nil { - return - } - return s.parsePacket(packet, common.NATIVE) - case HeartbeatType: - var heartbeat Heartbeat - heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) - if err != nil { - return - } - heartbeat.BytesLen() - } + return s.parsePacket(packet, common.NATIVE) + case HeartbeatType: + var heartbeat Heartbeat + heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) + if err != nil { return - }() + } + heartbeat.BytesLen() } + return } func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayMode) (err error) { <-s.authCh - if !s.authOk { + if !s.authOk.Load() { return } var assocId uint16 @@ -175,108 +119,79 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode common.UdpRelayM pc: pc, packet: packetPtr, rAddr: N.NewCustomAddr("tuic", fmt.Sprintf("tuic-%s-%d", s.uuid, assocId), s.quicConn.RemoteAddr()), // for tunnel's handleUDPConn - }, inbound.WithInUser(s.authUUID)) + }, inbound.WithInUser(s.authUUID.Load())) } -func (s *serverHandler) handleStream() (err error) { - for { - var quicStream quic.Stream - quicStream, err = s.quicConn.AcceptStream(context.Background()) - if err != nil { - return err - } - go func() (err error) { - stream := common.NewQuicStreamConn( - quicStream, - s.quicConn.LocalAddr(), - s.quicConn.RemoteAddr(), - nil, - ) - conn := N.NewBufferedConn(stream) - connect, err := ReadConnect(conn) - if err != nil { - return err - } - <-s.authCh - if !s.authOk { - return conn.Close() - } +func (s *serverHandler) HandleStream(conn *N.BufferedConn) (err error) { + connect, err := ReadConnect(conn) + if err != nil { + return err + } + <-s.authCh + if !s.authOk.Load() { + return conn.Close() + } - err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr(), inbound.WithInUser(s.authUUID)) - if err != nil { - _ = conn.Close() - return err - } - return - }() + err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr(), inbound.WithInUser(s.authUUID.Load())) + if err != nil { + _ = conn.Close() + return err } + return } -func (s *serverHandler) handleUniStream() (err error) { - for { - var stream quic.ReceiveStream - stream, err = s.quicConn.AcceptUniStream(context.Background()) +func (s *serverHandler) HandleUniStream(reader *bufio.Reader) (err error) { + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case AuthenticateType: + var authenticate Authenticate + authenticate, err = ReadAuthenticateWithHead(commandHead, reader) if err != nil { - return err + return } - go func() (err error) { - defer func() { - stream.CancelRead(0) - }() - reader := bufio.NewReader(stream) - commandHead, err := ReadCommandHead(reader) + authOk := false + var authUUID uuid.UUID + var token [32]byte + if password, ok := s.Users[authenticate.UUID]; ok { + token, err = GenToken(s.quicConn.ConnectionState(), authenticate.UUID, password) if err != nil { return } - switch commandHead.TYPE { - case AuthenticateType: - var authenticate Authenticate - authenticate, err = ReadAuthenticateWithHead(commandHead, reader) - if err != nil { - return - } - authOk := false - var authUUID uuid.UUID - var token [32]byte - if password, ok := s.Users[authenticate.UUID]; ok { - token, err = GenToken(s.quicConn.ConnectionState(), authenticate.UUID, password) - if err != nil { - return - } - if token == authenticate.TOKEN { - authOk = true - authUUID = authenticate.UUID - } - } - s.authOnce.Do(func() { - if !authOk { - _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") - } - s.authOk = authOk - s.authUUID = authUUID.String() - close(s.authCh) - }) - case PacketType: - var packet Packet - packet, err = ReadPacketWithHead(commandHead, reader) - if err != nil { - return - } - return s.parsePacket(packet, common.QUIC) - case DissociateType: - var disassociate Dissociate - disassociate, err = ReadDissociateWithHead(commandHead, reader) - if err != nil { - return - } - if v, loaded := s.udpInputMap.LoadAndDelete(disassociate.ASSOC_ID); loaded { - input := v.(*serverUDPInput) - input.writeClosed.Store(true) - } + if token == authenticate.TOKEN { + authOk = true + authUUID = authenticate.UUID } + } + s.authOnce.Do(func() { + if !authOk { + _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") + } + s.authOk.Store(authOk) + s.authUUID.Store(authUUID.String()) + close(s.authCh) + }) + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) + if err != nil { return - }() + } + return s.parsePacket(packet, common.QUIC) + case DissociateType: + var disassociate Dissociate + disassociate, err = ReadDissociateWithHead(commandHead, reader) + if err != nil { + return + } + if v, loaded := s.udpInputMap.LoadAndDelete(disassociate.ASSOC_ID); loaded { + input := v.(*serverUDPInput) + input.writeClosed.Store(true) + } } + return } type serverUDPInput struct {