Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transport: Add HTTP3 to HTTP #3819

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infra/conf/transport_internet.go
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ func (p TransportProtocol) Build() (string, error) {
return "mkcp", nil
case "ws", "websocket":
return "websocket", nil
case "h2", "http":
case "h2", "h3", "http":
return "http", nil
case "grpc", "gun":
return "grpc", nil
Expand Down
175 changes: 119 additions & 56 deletions transport/internet/http/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"sync"
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
c "github.com/xtls/xray-core/common/ctx"
Expand All @@ -24,6 +26,13 @@ import (
"golang.org/x/net/http2"
)

// defines the maximum time an idle TCP session can survive in the tunnel, so
// it should be consistent across HTTP versions and with other transports.
const connIdleTimeout = 300 * time.Second

// consistent with quic-go
const h3KeepalivePeriod = 10 * time.Second

type dialerConf struct {
net.Destination
*internet.MemoryStreamConfig
Expand All @@ -48,72 +57,129 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
if tlsConfigs == nil && realityConfigs == nil {
return nil, errors.New("TLS or REALITY must be enabled for http transport.").AtWarning()
}
isH3 := tlsConfigs != nil && (len(tlsConfigs.NextProtocol) == 1 && tlsConfigs.NextProtocol[0] == "h3")
if isH3 {
dest.Network = net.Network_UDP
}
sockopt := streamSettings.SocketSettings

if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client, nil
}

transport := &http2.Transport{
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if len(rawPort) == 0 {
rawPort = "443"
}
port, err := net.PortFromString(rawPort)
if err != nil {
return nil, err
}
address := net.ParseAddress(rawHost)
var transport http.RoundTripper
if isH3 {
quicConfig := &quic.Config{
MaxIdleTimeout: connIdleTimeout,

hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)
// these two are defaults of quic-go/http3. the default of quic-go (no
// http3) is different, so it is hardcoded here for clarity.
// https://github.com/quic-go/quic-go/blob/b8ea5c798155950fb5bbfdd06cad1939c9355878/http3/client.go#L36-L39
MaxIncomingStreams: -1,
KeepAlivePeriod: h3KeepalivePeriod,
}
roundTripper := &http3.RoundTripper{
QUICConfig: quicConfig,
TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil {
return nil, err
}

pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
var udpConn net.PacketConn
var udpAddr *net.UDPAddr

if realityConfigs != nil {
return reality.UClient(pconn, realityConfigs, hctx, dest)
}
switch c := conn.(type) {
case *internet.PacketConnWrapper:
var ok bool
udpConn, ok = c.Conn.(*net.UDPConn)
if !ok {
return nil, errors.New("PacketConnWrapper does not contain a UDP connection")
}
udpAddr, err = net.ResolveUDPAddr("udp", c.Dest.String())
if err != nil {
return nil, err
}
case *net.UDPConn:
udpConn = c
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
default:
udpConn = &internet.FakePacketConn{c}
udpAddr, err = net.ResolveUDPAddr("udp", c.RemoteAddr().String())
if err != nil {
return nil, err
}
}

var cn tls.Interface
if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
} else {
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
}
if err := cn.HandshakeContext(ctx); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
},
}
transport = roundTripper
} else {
transportH2 := &http2.Transport{
DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
rawHost, rawPort, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
if len(rawPort) == 0 {
rawPort = "443"
}
port, err := net.PortFromString(rawPort)
if err != nil {
return nil, err
}
address := net.ParseAddress(rawHost)

hctx = c.ContextWithID(hctx, c.IDFromContext(ctx))
hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
hctx = session.ContextWithTimeoutOnly(hctx, true)

pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
}
negotiatedProtocol := cn.NegotiatedProtocol()
if negotiatedProtocol != http2.NextProtoTLS {
return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
}
return cn, nil
},
}

if tlsConfigs != nil {
transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
}

if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)

if realityConfigs != nil {
return reality.UClient(pconn, realityConfigs, hctx, dest)
}

var cn tls.Interface
if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
} else {
cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
}
if err := cn.HandshakeContext(ctx); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr)
return nil, err
}
}
negotiatedProtocol := cn.NegotiatedProtocol()
if negotiatedProtocol != http2.NextProtoTLS {
return nil, errors.New("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
}
return cn, nil
},
}
if tlsConfigs != nil {
transportH2.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
}
if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
transportH2.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
transportH2.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
}
transport = transportH2
}

client := &http.Client{
Expand Down Expand Up @@ -158,9 +224,6 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
Host: dest.NetAddr(),
Path: httpSettings.getNormalizedPath(),
},
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
}
// Disable any compression method from server.
Expand Down
78 changes: 78 additions & 0 deletions transport/internet/http/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol/tls/cert"
"github.com/xtls/xray-core/testing/servers/tcp"
"github.com/xtls/xray-core/testing/servers/udp"
"github.com/xtls/xray-core/transport/internet"
. "github.com/xtls/xray-core/transport/internet/http"
"github.com/xtls/xray-core/transport/internet/stat"
Expand Down Expand Up @@ -92,3 +93,80 @@ func TestHTTPConnection(t *testing.T) {
t.Error(r)
}
}

func TestH3Connection(t *testing.T) {
port := udp.PickPort()

listener, err := Listen(context.Background(), net.LocalHostIP, port, &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("www.example.com")))},
},
}, func(conn stat.Connection) {
go func() {
defer conn.Close()

b := buf.New()
defer b.Release()

for {
if _, err := b.ReadFrom(conn); err != nil {
return
}
_, err := conn.Write(b.Bytes())
common.Must(err)
}
}()
})
common.Must(err)

defer listener.Close()

time.Sleep(time.Second)

dctx := context.Background()
conn, err := Dial(dctx, net.TCPDestination(net.LocalHostIP, port), &internet.MemoryStreamConfig{
ProtocolName: "http",
ProtocolSettings: &Config{},
SecurityType: "tls",
SecuritySettings: &tls.Config{
NextProtocol: []string{"h3"},
ServerName: "www.example.com",
AllowInsecure: true,
},
})
common.Must(err)
defer conn.Close()

const N = 1024
b1 := make([]byte, N)
common.Must2(rand.Read(b1))
b2 := buf.New()

nBytes, err := conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}

b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}

nBytes, err = conn.Write(b1)
common.Must(err)
if nBytes != N {
t.Error("write: ", nBytes)
}

b2.Clear()
common.Must2(b2.ReadFullFrom(conn, N))
if r := cmp.Diff(b2.Bytes(), b1); r != "" {
t.Error(r)
}
}
Loading