diff --git a/adapter/outbound/shadowsocks.go b/adapter/outbound/shadowsocks.go index 49ff45eda1..af034472b8 100644 --- a/adapter/outbound/shadowsocks.go +++ b/adapter/outbound/shadowsocks.go @@ -86,22 +86,14 @@ type restlsOption struct { // StreamConn implements C.ProxyAdapter func (ss *ShadowSocks) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { - switch ss.obfsMode { - case shadowtls.Mode: - // fix tls handshake not timeout - ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) - defer cancel() - var err error - c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption) - if err != nil { - return nil, err - } - - } - return ss.streamConn(c, metadata) + // fix tls handshake not timeout + ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout) + defer cancel() + return ss.StreamConnContext(ctx, c, metadata) } -func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) { +func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { + useEarly := false switch ss.obfsMode { case "tls": c = obfs.NewTLSObfs(c, ss.obfsOption.Host) @@ -114,21 +106,30 @@ func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e if err != nil { return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) } + case shadowtls.Mode: + var err error + c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption) + if err != nil { + return nil, err + } + useEarly = true case restls.Mode: var err error - c, err = restls.NewRestls(c, ss.restlsConfig) + c, err = restls.NewRestls(ctx, c, ss.restlsConfig) if err != nil { return nil, fmt.Errorf("%s (restls) connect error: %w", ss.addr, err) } + useEarly = true } + useEarly = useEarly || N.NeedHandshake(c) if metadata.NetWork == C.UDP && ss.option.UDPOverTCP { - if N.NeedHandshake(c) { + if useEarly { return ss.method.DialEarlyConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")), nil } else { return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")) } } - if N.NeedHandshake(c) { + if useEarly { return ss.method.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil } else { return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress())) @@ -152,15 +153,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale safeConnClose(c, err) }(c) - switch ss.obfsMode { - case shadowtls.Mode: - c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption) - if err != nil { - return nil, err - } - } - - c, err = ss.streamConn(c, metadata) + c, err = ss.StreamConnContext(ctx, c, metadata) return NewConn(c, ss), err } diff --git a/transport/restls/restls.go b/transport/restls/restls.go index 4b03b8254f..0f3ba8ac77 100644 --- a/transport/restls/restls.go +++ b/transport/restls/restls.go @@ -1,6 +1,7 @@ package restls import ( + "context" "net" tls "github.com/3andne/restls-client-go" @@ -10,48 +11,29 @@ const ( Mode string = "restls" ) -// Restls type Restls struct { - net.Conn - firstPacketCache []byte - firstPacket bool + *tls.UConn } -func (r *Restls) Read(b []byte) (int, error) { - if err := r.Conn.(*tls.UConn).Handshake(); err != nil { - return 0, err - } - n, err := r.Conn.(*tls.UConn).Read(b) - return n, err -} - -func (r *Restls) Write(b []byte) (int, error) { - if r.firstPacket { - r.firstPacketCache = append([]byte(nil), b...) - r.firstPacket = false - return len(b), nil - } - if len(r.firstPacketCache) != 0 { - b = append(r.firstPacketCache, b...) - r.firstPacketCache = nil - } - n, err := r.Conn.(*tls.UConn).Write(b) - return n, err +func (r *Restls) Upstream() any { + return r.UConn.NetConn() } // NewRestls return a Restls Connection -func NewRestls(conn net.Conn, config *tls.Config) (net.Conn, error) { +func NewRestls(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, error) { + clientHellowID := tls.HelloChrome_Auto if config != nil { clientIDPtr := config.ClientID.Load() if clientIDPtr != nil { - return &Restls{ - Conn: tls.UClient(conn, config, *clientIDPtr), - firstPacket: true, - }, nil + clientHellowID = *clientIDPtr } } - return &Restls{ - Conn: tls.UClient(conn, config, tls.HelloChrome_Auto), - firstPacket: true, - }, nil + restls := &Restls{ + UConn: tls.UClient(conn, config, clientHellowID), + } + if err := restls.HandshakeContext(ctx); err != nil { + return nil, err + } + + return restls, nil }