diff --git a/transport/vless/conn.go b/transport/vless/conn.go index 75eef495d5..0f4d3bf765 100644 --- a/transport/vless/conn.go +++ b/transport/vless/conn.go @@ -6,6 +6,8 @@ import ( "fmt" "io" "net" + "sync" + "time" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" @@ -21,6 +23,10 @@ type Conn struct { id *uuid.UUID addons *Addons received bool + + handshake chan struct{} + handshakeMutex sync.Mutex + err error } func (vc *Conn) Read(b []byte) (int, error) { @@ -47,7 +53,41 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error { return vc.ExtendedConn.ReadBuffer(buffer) } -func (vc *Conn) sendRequest() (err error) { +func (vc *Conn) Write(p []byte) (int, error) { + select { + case <-vc.handshake: + default: + err := vc.sendRequest(p) + if err != nil { + return 0, err + } + } + return vc.ExtendedConn.Write(p) +} + +func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error { + select { + case <-vc.handshake: + default: + err := vc.sendRequest(buffer.Bytes()) + if err != nil { + return err + } + } + return vc.ExtendedConn.WriteBuffer(buffer) +} + +func (vc *Conn) sendRequest(p []byte) (err error) { + vc.handshakeMutex.Lock() + defer vc.handshakeMutex.Unlock() + + select { + case <-vc.handshake: + return vc.err + default: + } + defer close(vc.handshake) + requestLen := 1 // protocol version requestLen += 16 // UUID requestLen += 1 // addons length @@ -65,6 +105,8 @@ func (vc *Conn) sendRequest() (err error) { requestLen += 1 // addr type requestLen += len(vc.dst.Addr) } + requestLen += len(p) + _buffer := buf.StackNewSize(requestLen) defer buf.KeepAlive(_buffer) buffer := buf.Dup(_buffer) @@ -93,25 +135,26 @@ func (vc *Conn) sendRequest() (err error) { ) } + buf.Must(buf.Error(buffer.Write(p))) + _, err = vc.ExtendedConn.Write(buffer.Bytes()) return } func (vc *Conn) recvResponse() error { - var err error var buf [1]byte - _, err = io.ReadFull(vc.ExtendedConn, buf[:]) - if err != nil { - return err + _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:]) + if vc.err != nil { + return vc.err } if buf[0] != Version { return errors.New("unexpected response version") } - _, err = io.ReadFull(vc.ExtendedConn, buf[:]) - if err != nil { - return err + _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:]) + if vc.err != nil { + return vc.err } length := int64(buf[0]) @@ -132,6 +175,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { ExtendedConn: N.NewExtendedConn(conn), id: client.uuid, dst: dst, + handshake: make(chan struct{}), } if !dst.UDP && client.Addons != nil { @@ -155,8 +199,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) { } } - if err := c.sendRequest(); err != nil { - return nil, err - } + go func() { + select { + case <-c.handshake: + case <-time.After(200 * time.Millisecond): + _ = c.sendRequest(nil) + } + }() return c, nil }