From a2f09b46a5e4e1f28b1b05c0f9f736d0969cad0d Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 00:20:52 +0800 Subject: [PATCH 01/13] chore: tidy workspace --- go.mod | 1 + go.sum | 2 ++ 2 files changed, 3 insertions(+) diff --git a/go.mod b/go.mod index 125c96c..bcb852d 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/onsi/ginkgo/v2 v2.11.0 // indirect + github.com/quic-go/qpack v0.4.0 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index b27da2f..8d32359 100644 --- a/go.sum +++ b/go.sum @@ -53,6 +53,8 @@ github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc= github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= +github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/refraction-networking/utls v1.6.4 h1:aeynTroaYn7y+mFtqv8D0bQ4bw0y9nJHneGxJ7lvRDM= github.com/refraction-networking/utls v1.6.4/go.mod h1:2VL2xfiqgFAZtJKeUTlf+PSYFs3Eu7km0gCtXJ3m8zs= github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb h1:XfLJSPIOUX+osiMraVgIrMR27uMXnRJWGm1+GL8/63U= From 57e5e4ee3850f536e6a4465f248b038106bf0c35 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 00:21:04 +0800 Subject: [PATCH 02/13] feat: introduce hysteria2 --- protocol/hysteria2/client/client.go | 315 ++++++ protocol/hysteria2/client/config.go | 112 ++ protocol/hysteria2/client/reconnect.go | 120 +++ 
protocol/hysteria2/client/udp.go | 185 ++++ protocol/hysteria2/errors/errors.go | 75 ++ .../internal/congestion/bbr/bandwidth.go | 27 + .../congestion/bbr/bandwidth_sampler.go | 874 ++++++++++++++++ .../internal/congestion/bbr/bbr_sender.go | 984 ++++++++++++++++++ .../internal/congestion/bbr/clock.go | 18 + .../bbr/packet_number_indexed_queue.go | 199 ++++ .../internal/congestion/bbr/ringbuffer.go | 118 +++ .../congestion/bbr/windowed_filter.go | 162 +++ .../internal/congestion/brutal/brutal.go | 185 ++++ .../internal/congestion/common/pacer.go | 79 ++ .../hysteria2/internal/congestion/utils.go | 18 + protocol/hysteria2/internal/frag/frag.go | 77 ++ protocol/hysteria2/internal/frag/frag_test.go | 336 ++++++ protocol/hysteria2/internal/pmtud/avail.go | 7 + protocol/hysteria2/internal/pmtud/unavail.go | 13 + protocol/hysteria2/internal/protocol/http.go | 68 ++ .../hysteria2/internal/protocol/padding.go | 31 + protocol/hysteria2/internal/protocol/proxy.go | 255 +++++ .../hysteria2/internal/protocol/proxy_test.go | 317 ++++++ protocol/hysteria2/internal/utils/atomic.go | 24 + protocol/hysteria2/internal/utils/qstream.go | 62 ++ 25 files changed, 4661 insertions(+) create mode 100644 protocol/hysteria2/client/client.go create mode 100644 protocol/hysteria2/client/config.go create mode 100644 protocol/hysteria2/client/reconnect.go create mode 100644 protocol/hysteria2/client/udp.go create mode 100644 protocol/hysteria2/errors/errors.go create mode 100644 protocol/hysteria2/internal/congestion/bbr/bandwidth.go create mode 100644 protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go create mode 100644 protocol/hysteria2/internal/congestion/bbr/bbr_sender.go create mode 100644 protocol/hysteria2/internal/congestion/bbr/clock.go create mode 100644 protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go create mode 100644 protocol/hysteria2/internal/congestion/bbr/ringbuffer.go create mode 100644 
protocol/hysteria2/internal/congestion/bbr/windowed_filter.go create mode 100644 protocol/hysteria2/internal/congestion/brutal/brutal.go create mode 100644 protocol/hysteria2/internal/congestion/common/pacer.go create mode 100644 protocol/hysteria2/internal/congestion/utils.go create mode 100644 protocol/hysteria2/internal/frag/frag.go create mode 100644 protocol/hysteria2/internal/frag/frag_test.go create mode 100644 protocol/hysteria2/internal/pmtud/avail.go create mode 100644 protocol/hysteria2/internal/pmtud/unavail.go create mode 100644 protocol/hysteria2/internal/protocol/http.go create mode 100644 protocol/hysteria2/internal/protocol/padding.go create mode 100644 protocol/hysteria2/internal/protocol/proxy.go create mode 100644 protocol/hysteria2/internal/protocol/proxy_test.go create mode 100644 protocol/hysteria2/internal/utils/atomic.go create mode 100644 protocol/hysteria2/internal/utils/qstream.go diff --git a/protocol/hysteria2/client/client.go b/protocol/hysteria2/client/client.go new file mode 100644 index 0000000..5ef5a98 --- /dev/null +++ b/protocol/hysteria2/client/client.go @@ -0,0 +1,315 @@ +package client + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/url" + "time" + + coreErrs "github.com/daeuniverse/outbound/protocol/hysteria2/errors" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/utils" + + "github.com/daeuniverse/quic-go" + "github.com/daeuniverse/quic-go/http3" +) + +const ( + closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError + closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError +) + +type Client interface { + TCP(addr string) (net.Conn, error) + UDP() (HyUDPConn, error) + Close() error +} + +type HyUDPConn interface { + Receive() ([]byte, string, error) + Send([]byte, string) error + Close() error +} + +type HandshakeInfo struct { + UDPEnabled 
bool + Tx uint64 // 0 if using BBR +} + +func NewClient(config *Config) (Client, *HandshakeInfo, error) { + if err := config.verifyAndFill(); err != nil { + return nil, nil, err + } + c := &clientImpl{ + config: config, + } + info, err := c.connect() + if err != nil { + return nil, nil, err + } + return c, info, nil +} + +type clientImpl struct { + config *Config + + pktConn net.PacketConn + conn quic.Connection + + udpSM *udpSessionManager +} + +func (c *clientImpl) connect() (*HandshakeInfo, error) { + pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr) + if err != nil { + return nil, err + } + // Convert config to TLS config & QUIC config + tlsConfig := &tls.Config{ + ServerName: c.config.TLSConfig.ServerName, + InsecureSkipVerify: c.config.TLSConfig.InsecureSkipVerify, + VerifyPeerCertificate: c.config.TLSConfig.VerifyPeerCertificate, + RootCAs: c.config.TLSConfig.RootCAs, + } + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: c.config.QUICConfig.InitialStreamReceiveWindow, + MaxStreamReceiveWindow: c.config.QUICConfig.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: c.config.QUICConfig.InitialConnectionReceiveWindow, + MaxConnectionReceiveWindow: c.config.QUICConfig.MaxConnectionReceiveWindow, + MaxIdleTimeout: c.config.QUICConfig.MaxIdleTimeout, + KeepAlivePeriod: c.config.QUICConfig.KeepAlivePeriod, + DisablePathMTUDiscovery: c.config.QUICConfig.DisablePathMTUDiscovery, + EnableDatagrams: true, + } + // Prepare RoundTripper + var conn quic.EarlyConnection + rt := &http3.RoundTripper{ + TLSClientConfig: tlsConfig, + QuicConfig: quicConfig, + Dial: func(ctx context.Context, _ string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + qc, err := quic.DialEarly(ctx, pktConn, c.config.ServerAddr, tlsCfg, cfg) + if err != nil { + return nil, err + } + conn = qc + return qc, nil + }, + } + // Send auth HTTP request + req := &http.Request{ + Method: http.MethodPost, + URL: &url.URL{ + Scheme: "https", + Host: 
protocol.URLHost, + Path: protocol.URLPath, + }, + Header: make(http.Header), + } + protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{ + Auth: c.config.Auth, + Rx: c.config.BandwidthConfig.MaxRx, + }) + resp, err := rt.RoundTrip(req) + if err != nil { + if conn != nil { + _ = conn.CloseWithError(closeErrCodeProtocolError, "") + } + _ = pktConn.Close() + return nil, coreErrs.ConnectError{Err: err} + } + if resp.StatusCode != protocol.StatusAuthOK { + _ = conn.CloseWithError(closeErrCodeProtocolError, "") + _ = pktConn.Close() + return nil, coreErrs.AuthError{StatusCode: resp.StatusCode} + } + // Auth OK + authResp := protocol.AuthResponseFromHeader(resp.Header) + var actualTx uint64 + if authResp.RxAuto { + // Server asks client to use bandwidth detection, + // ignore local bandwidth config and use BBR + congestion.UseBBR(conn) + } else { + // actualTx = min(serverRx, clientTx) + actualTx = authResp.Rx + if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { + // Server doesn't have a limit, or our clientTx is smaller than serverRx + actualTx = c.config.BandwidthConfig.MaxTx + } + if actualTx > 0 { + congestion.UseBrutal(conn, actualTx) + } else { + // We don't know our own bandwidth either, use BBR + congestion.UseBBR(conn) + } + } + _ = resp.Body.Close() + + c.pktConn = pktConn + c.conn = conn + if authResp.UDPEnabled { + c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) + } + return &HandshakeInfo{ + UDPEnabled: authResp.UDPEnabled, + Tx: actualTx, + }, nil +} + +// openStream wraps the stream with QStream, which handles Close() properly +func (c *clientImpl) openStream() (quic.Stream, error) { + stream, err := c.conn.OpenStream() + if err != nil { + return nil, err + } + return &utils.QStream{Stream: stream}, nil +} + +func (c *clientImpl) TCP(addr string) (net.Conn, error) { + stream, err := c.openStream() + if err != nil { + return nil, wrapIfConnectionClosed(err) + } + // Send request + err = protocol.WriteTCPRequest(stream, addr) + 
if err != nil { + _ = stream.Close() + return nil, wrapIfConnectionClosed(err) + } + if c.config.FastOpen { + // Don't wait for the response when fast open is enabled. + // Return the connection immediately, defer the response handling + // to the first Read() call. + return &tcpConn{ + Orig: stream, + PseudoLocalAddr: c.conn.LocalAddr(), + PseudoRemoteAddr: c.conn.RemoteAddr(), + Established: false, + }, nil + } + // Read response + ok, msg, err := protocol.ReadTCPResponse(stream) + if err != nil { + _ = stream.Close() + return nil, wrapIfConnectionClosed(err) + } + if !ok { + _ = stream.Close() + return nil, coreErrs.DialError{Message: msg} + } + return &tcpConn{ + Orig: stream, + PseudoLocalAddr: c.conn.LocalAddr(), + PseudoRemoteAddr: c.conn.RemoteAddr(), + Established: true, + }, nil +} + +func (c *clientImpl) UDP() (HyUDPConn, error) { + if c.udpSM == nil { + return nil, coreErrs.DialError{Message: "UDP not enabled"} + } + return c.udpSM.NewUDP() +} + +func (c *clientImpl) Close() error { + _ = c.conn.CloseWithError(closeErrCodeOK, "") + _ = c.pktConn.Close() + return nil +} + +// wrapIfConnectionClosed checks if the error returned by quic-go +// indicates that the QUIC connection has been permanently closed, +// and if so, wraps the error with coreErrs.ClosedError. +// PITFALL: sometimes quic-go has "internal errors" that are not net.Error, +// but we still need to treat them as ClosedError. 
+func wrapIfConnectionClosed(err error) error { + netErr, ok := err.(net.Error) + if !ok || !netErr.Temporary() { + return coreErrs.ClosedError{Err: err} + } else { + return err + } +} + +type tcpConn struct { + Orig quic.Stream + PseudoLocalAddr net.Addr + PseudoRemoteAddr net.Addr + Established bool +} + +func (c *tcpConn) Read(b []byte) (n int, err error) { + if !c.Established { + // Read response + ok, msg, err := protocol.ReadTCPResponse(c.Orig) + if err != nil { + return 0, err + } + if !ok { + return 0, coreErrs.DialError{Message: msg} + } + c.Established = true + } + return c.Orig.Read(b) +} + +func (c *tcpConn) Write(b []byte) (n int, err error) { + return c.Orig.Write(b) +} + +func (c *tcpConn) Close() error { + return c.Orig.Close() +} + +func (c *tcpConn) LocalAddr() net.Addr { + return c.PseudoLocalAddr +} + +func (c *tcpConn) RemoteAddr() net.Addr { + return c.PseudoRemoteAddr +} + +func (c *tcpConn) SetDeadline(t time.Time) error { + return c.Orig.SetDeadline(t) +} + +func (c *tcpConn) SetReadDeadline(t time.Time) error { + return c.Orig.SetReadDeadline(t) +} + +func (c *tcpConn) SetWriteDeadline(t time.Time) error { + return c.Orig.SetWriteDeadline(t) +} + +type udpIOImpl struct { + Conn quic.Connection +} + +func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) { + for { + msg, err := io.Conn.ReceiveDatagram(context.Background()) + if err != nil { + // Connection error, this will stop the session manager + return nil, err + } + udpMsg, err := protocol.ParseUDPMessage(msg) + if err != nil { + // Invalid message, this is fine - just wait for the next + continue + } + return udpMsg, nil + } +} + +func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error { + msgN := msg.Serialize(buf) + if msgN < 0 { + // Message larger than buffer, silent drop + return nil + } + return io.Conn.SendDatagram(buf[:msgN]) +} diff --git a/protocol/hysteria2/client/config.go b/protocol/hysteria2/client/config.go new file mode 100644 index 
0000000..baf9f23 --- /dev/null +++ b/protocol/hysteria2/client/config.go @@ -0,0 +1,112 @@ +package client + +import ( + "crypto/x509" + "net" + "time" + + "github.com/daeuniverse/outbound/protocol/hysteria2/errors" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/pmtud" +) + +const ( + defaultStreamReceiveWindow = 8388608 // 8MB + defaultConnReceiveWindow = defaultStreamReceiveWindow * 5 / 2 // 20MB + defaultMaxIdleTimeout = 30 * time.Second + defaultKeepAlivePeriod = 10 * time.Second +) + +type Config struct { + ConnFactory ConnFactory + ServerAddr net.Addr + Auth string + TLSConfig TLSConfig + QUICConfig QUICConfig + BandwidthConfig BandwidthConfig + FastOpen bool + + filled bool // whether the fields have been verified and filled +} + +// verifyAndFill fills the fields that are not set by the user with default values when possible, +// and returns an error if the user has not set a required field or has set an invalid value. +func (c *Config) verifyAndFill() error { + if c.filled { + return nil + } + if c.ConnFactory == nil { + c.ConnFactory = &udpConnFactory{} + } + if c.ServerAddr == nil { + return errors.ConfigError{Field: "ServerAddr", Reason: "must be set"} + } + if c.QUICConfig.InitialStreamReceiveWindow == 0 { + c.QUICConfig.InitialStreamReceiveWindow = defaultStreamReceiveWindow + } else if c.QUICConfig.InitialStreamReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.InitialStreamReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.MaxStreamReceiveWindow == 0 { + c.QUICConfig.MaxStreamReceiveWindow = defaultStreamReceiveWindow + } else if c.QUICConfig.MaxStreamReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.MaxStreamReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.InitialConnectionReceiveWindow == 0 { + c.QUICConfig.InitialConnectionReceiveWindow = defaultConnReceiveWindow + } else if c.QUICConfig.InitialConnectionReceiveWindow < 16384 { + return 
errors.ConfigError{Field: "QUICConfig.InitialConnectionReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.MaxConnectionReceiveWindow == 0 { + c.QUICConfig.MaxConnectionReceiveWindow = defaultConnReceiveWindow + } else if c.QUICConfig.MaxConnectionReceiveWindow < 16384 { + return errors.ConfigError{Field: "QUICConfig.MaxConnectionReceiveWindow", Reason: "must be at least 16384"} + } + if c.QUICConfig.MaxIdleTimeout == 0 { + c.QUICConfig.MaxIdleTimeout = defaultMaxIdleTimeout + } else if c.QUICConfig.MaxIdleTimeout < 4*time.Second || c.QUICConfig.MaxIdleTimeout > 120*time.Second { + return errors.ConfigError{Field: "QUICConfig.MaxIdleTimeout", Reason: "must be between 4s and 120s"} + } + if c.QUICConfig.KeepAlivePeriod == 0 { + c.QUICConfig.KeepAlivePeriod = defaultKeepAlivePeriod + } else if c.QUICConfig.KeepAlivePeriod < 2*time.Second || c.QUICConfig.KeepAlivePeriod > 60*time.Second { + return errors.ConfigError{Field: "QUICConfig.KeepAlivePeriod", Reason: "must be between 2s and 60s"} + } + c.QUICConfig.DisablePathMTUDiscovery = c.QUICConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery + + c.filled = true + return nil +} + +type ConnFactory interface { + New(net.Addr) (net.PacketConn, error) +} + +type udpConnFactory struct{} + +func (f *udpConnFactory) New(addr net.Addr) (net.PacketConn, error) { + return net.ListenUDP("udp", nil) +} + +// TLSConfig contains the TLS configuration fields that we want to expose to the user. +type TLSConfig struct { + ServerName string + InsecureSkipVerify bool + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + RootCAs *x509.CertPool +} + +// QUICConfig contains the QUIC configuration fields that we want to expose to the user. 
+type QUICConfig struct { + InitialStreamReceiveWindow uint64 + MaxStreamReceiveWindow uint64 + InitialConnectionReceiveWindow uint64 + MaxConnectionReceiveWindow uint64 + MaxIdleTimeout time.Duration + KeepAlivePeriod time.Duration + DisablePathMTUDiscovery bool // verifyAndFill forces this to true when pmtud.DisablePathMTUDiscovery is set (unsupported platforms). +} + +// BandwidthConfig describes the maximum bandwidth that the client can use, in bytes per second. +type BandwidthConfig struct { + MaxTx uint64 + MaxRx uint64 +} diff --git a/protocol/hysteria2/client/reconnect.go b/protocol/hysteria2/client/reconnect.go new file mode 100644 index 0000000..659397c --- /dev/null +++ b/protocol/hysteria2/client/reconnect.go @@ -0,0 +1,120 @@ +package client + +import ( + "net" + "sync" + + coreErrs "github.com/daeuniverse/outbound/protocol/hysteria2/errors" +) + +// reconnectableClientImpl is a wrapper of Client, which can reconnect when the connection is closed, +// except when the caller explicitly calls Close() to permanently close this client. +type reconnectableClientImpl struct { + configFunc func() (*Config, error) // called before connecting + connectedFunc func(Client, *HandshakeInfo, int) // called when successfully connected + client Client + count int + m sync.Mutex + closed bool // permanent close +} + +// NewReconnectableClient creates a reconnectable client. +// If lazy is true, the client will not connect until the first call to TCP() or UDP(). +// We use a function for config mainly to delay config evaluation +// (which involves DNS resolution) until the actual connection attempt.
+func NewReconnectableClient(configFunc func() (*Config, error), connectedFunc func(Client, *HandshakeInfo, int), lazy bool) (Client, error) { + rc := &reconnectableClientImpl{ + configFunc: configFunc, + connectedFunc: connectedFunc, + } + if !lazy { + if err := rc.reconnect(); err != nil { + return nil, err + } + } + return rc, nil +} + +func (rc *reconnectableClientImpl) reconnect() error { + if rc.client != nil { + _ = rc.client.Close() + } + var info *HandshakeInfo + config, err := rc.configFunc() + if err != nil { + return err + } + rc.client, info, err = NewClient(config) + if err != nil { + return err + } else { + rc.count++ + if rc.connectedFunc != nil { + rc.connectedFunc(rc, info, rc.count) + } + return nil + } +} + +// clientDo calls f with the current client. +// If the client is nil, it will first reconnect. +// It will also detect if the client is closed, and if so, +// set it to nil for reconnect next time. +func (rc *reconnectableClientImpl) clientDo(f func(Client) (interface{}, error)) (interface{}, error) { + rc.m.Lock() + if rc.closed { + rc.m.Unlock() + return nil, coreErrs.ClosedError{} + } + if rc.client == nil { + // No active connection, connect first + if err := rc.reconnect(); err != nil { + rc.m.Unlock() + return nil, err + } + } + client := rc.client + rc.m.Unlock() + + ret, err := f(client) + if _, ok := err.(coreErrs.ClosedError); ok { + // Connection closed, set client to nil for reconnect next time + rc.m.Lock() + if rc.client == client { + // This check is in case the client is already changed by another goroutine + rc.client = nil + } + rc.m.Unlock() + } + return ret, err +} + +func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { + if c, err := rc.clientDo(func(client Client) (interface{}, error) { + return client.TCP(addr) + }); err != nil { + return nil, err + } else { + return c.(net.Conn), nil + } +} + +func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { + if c, err := rc.clientDo(func(client 
Client) (interface{}, error) { + return client.UDP() + }); err != nil { + return nil, err + } else { + return c.(HyUDPConn), nil + } +} + +func (rc *reconnectableClientImpl) Close() error { + rc.m.Lock() + defer rc.m.Unlock() + rc.closed = true + if rc.client != nil { + return rc.client.Close() + } + return nil +} diff --git a/protocol/hysteria2/client/udp.go b/protocol/hysteria2/client/udp.go new file mode 100644 index 0000000..5378f56 --- /dev/null +++ b/protocol/hysteria2/client/udp.go @@ -0,0 +1,185 @@ +package client + +import ( + "errors" + "io" + "math/rand" + "sync" + + "github.com/daeuniverse/quic-go" + + coreErrs "github.com/daeuniverse/outbound/protocol/hysteria2/errors" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/frag" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" +) + +const ( + udpMessageChanSize = 1024 +) + +type udpIO interface { + ReceiveMessage() (*protocol.UDPMessage, error) + SendMessage([]byte, *protocol.UDPMessage) error +} + +type udpConn struct { + ID uint32 + D *frag.Defragger + ReceiveCh chan *protocol.UDPMessage + SendBuf []byte + SendFunc func([]byte, *protocol.UDPMessage) error + CloseFunc func() + Closed bool +} + +func (u *udpConn) Receive() ([]byte, string, error) { + for { + msg := <-u.ReceiveCh + if msg == nil { + // Closed + return nil, "", io.EOF + } + dfMsg := u.D.Feed(msg) + if dfMsg == nil { + // Incomplete message, wait for more + continue + } + return dfMsg.Data, dfMsg.Addr, nil + } +} + +// Send is not thread-safe, as it uses a shared SendBuf. 
+func (u *udpConn) Send(data []byte, addr string) error { + // Try no frag first + msg := &protocol.UDPMessage{ + SessionID: u.ID, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: addr, + Data: data, + } + err := u.SendFunc(u.SendBuf, msg) + var errTooLarge *quic.DatagramTooLargeError + if errors.As(err, &errTooLarge) { + // Message too large, try fragmentation + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := frag.FragUDPMessage(msg, int(errTooLarge.MaxDataLen)) + for _, fMsg := range fMsgs { + err := u.SendFunc(u.SendBuf, &fMsg) + if err != nil { + return err + } + } + return nil + } else { + return err + } +} + +func (u *udpConn) Close() error { + u.CloseFunc() + return nil +} + +type udpSessionManager struct { + io udpIO + + mutex sync.RWMutex + m map[uint32]*udpConn + nextID uint32 + + closed bool +} + +func newUDPSessionManager(io udpIO) *udpSessionManager { + m := &udpSessionManager{ + io: io, + m: make(map[uint32]*udpConn), + nextID: 1, + } + go m.run() + return m +} + +func (m *udpSessionManager) run() error { + defer m.closeCleanup() + for { + msg, err := m.io.ReceiveMessage() + if err != nil { + return err + } + m.feed(msg) + } +} + +func (m *udpSessionManager) closeCleanup() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, conn := range m.m { + m.close(conn) + } + m.closed = true +} + +func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + conn, ok := m.m[msg.SessionID] + if !ok { + // Ignore message from unknown session + return + } + + select { + case conn.ReceiveCh <- msg: + // OK + default: + // Channel full, drop the message + } +} + +// NewUDP creates a new UDP session. 
+func (m *udpSessionManager) NewUDP() (HyUDPConn, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.closed { + return nil, coreErrs.ClosedError{} + } + + id := m.nextID + m.nextID++ + + conn := &udpConn{ + ID: id, + D: &frag.Defragger{}, + ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize), + SendBuf: make([]byte, protocol.MaxUDPSize), + SendFunc: m.io.SendMessage, + } + conn.CloseFunc = func() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.close(conn) + } + m.m[id] = conn + + return conn, nil +} + +func (m *udpSessionManager) close(conn *udpConn) { + if !conn.Closed { + conn.Closed = true + close(conn.ReceiveCh) + delete(m.m, conn.ID) + } +} + +func (m *udpSessionManager) Count() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + return len(m.m) +} diff --git a/protocol/hysteria2/errors/errors.go b/protocol/hysteria2/errors/errors.go new file mode 100644 index 0000000..cb69118 --- /dev/null +++ b/protocol/hysteria2/errors/errors.go @@ -0,0 +1,75 @@ +package errors + +import ( + "fmt" + "strconv" +) + +// ConfigError is returned when a configuration field is invalid. +type ConfigError struct { + Field string + Reason string +} + +func (c ConfigError) Error() string { + return fmt.Sprintf("invalid config: %s: %s", c.Field, c.Reason) +} + +// ConnectError is returned when the client fails to connect to the server. +type ConnectError struct { + Err error +} + +func (c ConnectError) Error() string { + return "connect error: " + c.Err.Error() +} + +func (c ConnectError) Unwrap() error { + return c.Err +} + +// AuthError is returned when the client fails to authenticate with the server. +type AuthError struct { + StatusCode int +} + +func (a AuthError) Error() string { + return "authentication error, HTTP status code: " + strconv.Itoa(a.StatusCode) +} + +// DialError is returned when the server rejects the client's dial request. +// This applies to both TCP and UDP. 
+type DialError struct { + Message string +} + +func (c DialError) Error() string { + return "dial error: " + c.Message +} + +// ClosedError is returned when the client attempts to use a closed connection. +type ClosedError struct { + Err error // Can be nil +} + +func (c ClosedError) Error() string { + if c.Err == nil { + return "connection closed" + } else { + return "connection closed: " + c.Err.Error() + } +} + +func (c ClosedError) Unwrap() error { + return c.Err +} + +// ProtocolError is returned when the server/client runs into an unexpected +// or malformed request/response/message. +type ProtocolError struct { + Message string +} + +func (p ProtocolError) Error() string { + return "protocol error: " + p.Message +} diff --git a/protocol/hysteria2/internal/congestion/bbr/bandwidth.go b/protocol/hysteria2/internal/congestion/bbr/bandwidth.go new file mode 100644 index 0000000..23d870d --- /dev/null +++ b/protocol/hysteria2/internal/congestion/bbr/bandwidth.go @@ -0,0 +1,27 @@ +package bbr + +import ( + "math" + "time" + + "github.com/daeuniverse/quic-go/congestion" +) + +const ( + infBandwidth = Bandwidth(math.MaxUint64) +) + +// Bandwidth of a connection +type Bandwidth uint64 + +const ( + // BitsPerSecond is 1 bit per second + BitsPerSecond Bandwidth = 1 + // BytesPerSecond is 1 byte per second + BytesPerSecond = 8 * BitsPerSecond +) + +// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta +func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { + return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond +} diff --git a/protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go b/protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go new file mode 100644 index 0000000..4b28d42 --- /dev/null +++ b/protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go @@ -0,0 +1,874 @@ +package bbr + +import ( + "math" + "time" + + 
"github.com/daeuniverse/quic-go/congestion" +) + +const ( + infRTT = time.Duration(math.MaxInt64) + defaultConnectionStateMapQueueSize = 256 + defaultCandidatesBufferSize = 256 +) + +type roundTripCount uint64 + +// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned +// to the caller when the packet is acked or lost. +type sendTimeState struct { + // Whether other states in this object are valid. + isValid bool + // Whether the sender is app limited at the time the packet was sent. + // App limited bandwidth sample might be artificially low because the sender + // did not have enough data to send in order to saturate the link. + isAppLimited bool + // Total number of sent bytes at the time the packet was sent. + // Includes the packet itself. + totalBytesSent congestion.ByteCount + // Total number of acked bytes at the time the packet was sent. + totalBytesAcked congestion.ByteCount + // Total number of lost bytes at the time the packet was sent. + totalBytesLost congestion.ByteCount + // Total number of inflight bytes at the time the packet was sent. + // Includes the packet itself. + // It should be equal to |total_bytes_sent| minus the sum of + // |total_bytes_acked|, |total_bytes_lost| and total neutered bytes. + bytesInFlight congestion.ByteCount +} + +func newSendTimeState( + isAppLimited bool, + totalBytesSent congestion.ByteCount, + totalBytesAcked congestion.ByteCount, + totalBytesLost congestion.ByteCount, + bytesInFlight congestion.ByteCount, +) *sendTimeState { + return &sendTimeState{ + isValid: true, + isAppLimited: isAppLimited, + totalBytesSent: totalBytesSent, + totalBytesAcked: totalBytesAcked, + totalBytesLost: totalBytesLost, + bytesInFlight: bytesInFlight, + } +} + +type extraAckedEvent struct { + // The excess bytes acknowledged in the time delta for this event. + extraAcked congestion.ByteCount + + // The bytes acknowledged and time delta from the event. 
+ bytesAcked congestion.ByteCount + timeDelta time.Duration + // The round trip of the event. + round roundTripCount +} + +func maxExtraAckedEventFunc(a, b extraAckedEvent) int { + if a.extraAcked > b.extraAcked { + return 1 + } else if a.extraAcked < b.extraAcked { + return -1 + } + return 0 +} + +// BandwidthSample +type bandwidthSample struct { + // The bandwidth at that particular sample. Zero if no valid bandwidth sample + // is available. + bandwidth Bandwidth + // The RTT measurement at this particular sample. Zero if no RTT sample is + // available. Does not correct for delayed ack time. + rtt time.Duration + // |send_rate| is computed from the current packet being acked('P') and an + // earlier packet that is acked before P was sent. + sendRate Bandwidth + // States captured when the packet was sent. + stateAtSend sendTimeState +} + +func newBandwidthSample() *bandwidthSample { + return &bandwidthSample{ + sendRate: infBandwidth, + } +} + +// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every +// ack event to keep track of the degree of ack aggregation(a.k.a "ack height"). +type maxAckHeightTracker struct { + // Tracks the maximum number of bytes acked faster than the estimated + // bandwidth. + maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount] + // The time this aggregation started and the number of bytes acked during it. + aggregationEpochStartTime time.Time + aggregationEpochBytes congestion.ByteCount + // The last sent packet number before the current aggregation epoch started. + lastSentPacketNumberBeforeEpoch congestion.PacketNumber + // The number of ack aggregation epochs ever started, including the ongoing + // one. Stats only. 
+ numAckAggregationEpochs uint64 + ackAggregationBandwidthThreshold float64 + startNewAggregationEpochAfterFullRound bool + reduceExtraAckedOnBandwidthIncrease bool +} + +func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker { + return &maxAckHeightTracker{ + maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc), + lastSentPacketNumberBeforeEpoch: invalidPacketNumber, + ackAggregationBandwidthThreshold: 1.0, + } +} + +func (m *maxAckHeightTracker) Get() congestion.ByteCount { + return m.maxAckHeightFilter.GetBest().extraAcked +} + +func (m *maxAckHeightTracker) Update( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, + lastSentPacketNumber congestion.PacketNumber, + lastAckedPacketNumber congestion.PacketNumber, + ackTime time.Time, + bytesAcked congestion.ByteCount, +) congestion.ByteCount { + forceNewEpoch := false + + if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth { + // Save and clear existing entries. + best := m.maxAckHeightFilter.GetBest() + secondBest := m.maxAckHeightFilter.GetSecondBest() + thirdBest := m.maxAckHeightFilter.GetThirdBest() + m.maxAckHeightFilter.Clear() + + // Reinsert the heights into the filter after recalculating. 
+ expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta) + if expectedBytesAcked < best.bytesAcked { + best.extraAcked = best.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(best, best.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta) + if expectedBytesAcked < secondBest.bytesAcked { + secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(secondBest, secondBest.round) + } + expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta) + if expectedBytesAcked < thirdBest.bytesAcked { + thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked + m.maxAckHeightFilter.Update(thirdBest, thirdBest.round) + } + } + + // If any packet sent after the start of the epoch has been acked, start a new + // epoch. + if m.startNewAggregationEpochAfterFullRound && + m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber && + lastAckedPacketNumber != invalidPacketNumber && + lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch { + forceNewEpoch = true + } + if m.aggregationEpochStartTime.IsZero() || forceNewEpoch { + m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + // Compute how many bytes are expected to be delivered, assuming max bandwidth + // is correct. + aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime) + expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta) + // Reset the current aggregation epoch as soon as the ack arrival rate is less + // than or equal to the max bandwidth. + if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) { + // Reset to start measuring a new aggregation epoch. 
+ m.aggregationEpochBytes = bytesAcked + m.aggregationEpochStartTime = ackTime + m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber + m.numAckAggregationEpochs++ + return 0 + } + + m.aggregationEpochBytes += bytesAcked + + // Compute how many extra bytes were delivered vs max bandwidth; this excess (not the expected count) is what the filter ranks and Get() reports. + extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked + newEvent := extraAckedEvent{ + extraAcked: extraBytesAcked, + bytesAcked: m.aggregationEpochBytes, + timeDelta: aggregationDelta, + } + m.maxAckHeightFilter.Update(newEvent, roundTripCount) + return extraBytesAcked +} + +func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) { + m.maxAckHeightFilter.SetWindowLength(length) +} + +func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) { + newEvent := extraAckedEvent{ + extraAcked: newHeight, + round: newTime, + } + m.maxAckHeightFilter.Reset(newEvent, newTime) +} + +func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) { + m.ackAggregationBandwidthThreshold = threshold +} + +func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) { + m.startNewAggregationEpochAfterFullRound = value +} + +func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) { + m.reduceExtraAckedOnBandwidthIncrease = value +} + +func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 { + return m.ackAggregationBandwidthThreshold +} + +func (m *maxAckHeightTracker) NumAckAggregationEpochs() uint64 { + return m.numAckAggregationEpochs +} + +// AckPoint represents a point on the ack line. +type ackPoint struct { + ackTime time.Time + totalBytesAcked congestion.ByteCount +} + +// RecentAckPoints maintains the most recent 2 ack points at distinct times.
+type recentAckPoints struct { + ackPoints [2]ackPoint +} + +func (r *recentAckPoints) Update(ackTime time.Time, totalBytesAcked congestion.ByteCount) { + if ackTime.Before(r.ackPoints[1].ackTime) { + r.ackPoints[1].ackTime = ackTime + } else if ackTime.After(r.ackPoints[1].ackTime) { + r.ackPoints[0] = r.ackPoints[1] + r.ackPoints[1].ackTime = ackTime + } + + r.ackPoints[1].totalBytesAcked = totalBytesAcked +} + +func (r *recentAckPoints) Clear() { + r.ackPoints[0] = ackPoint{} + r.ackPoints[1] = ackPoint{} +} + +func (r *recentAckPoints) MostRecentPoint() *ackPoint { + return &r.ackPoints[1] +} + +func (r *recentAckPoints) LessRecentPoint() *ackPoint { + if r.ackPoints[0].totalBytesAcked != 0 { + return &r.ackPoints[0] + } + + return &r.ackPoints[1] +} + +// ConnectionStateOnSentPacket represents the information about a sent packet +// and the state of the connection at the moment the packet was sent, +// specifically the information about the most recently acknowledged packet at +// that moment. +type connectionStateOnSentPacket struct { + // Time at which the packet is sent. + sentTime time.Time + // Size of the packet. + size congestion.ByteCount + // The value of |totalBytesSentAtLastAckedPacket| at the time the + // packet was sent. + totalBytesSentAtLastAckedPacket congestion.ByteCount + // The value of |lastAckedPacketSentTime| at the time the packet was + // sent. + lastAckedPacketSentTime time.Time + // The value of |lastAckedPacketAckTime| at the time the packet was + // sent. + lastAckedPacketAckTime time.Time + // Send time states that are returned to the congestion controller when the + // packet is acked or lost. + sendTimeState sendTimeState +} + +// Snapshot constructor. Records the current state of the bandwidth +// sampler. +// |bytes_in_flight| is the bytes in flight right after the packet is sent. 
+func newConnectionStateOnSentPacket(
+	sentTime time.Time,
+	size congestion.ByteCount,
+	bytesInFlight congestion.ByteCount,
+	sampler *bandwidthSampler,
+) *connectionStateOnSentPacket {
+	return &connectionStateOnSentPacket{
+		sentTime: sentTime,
+		size:     size,
+		// Copy the sampler's "last acked packet" bookkeeping as it is right now.
+		totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket,
+		lastAckedPacketSentTime:         sampler.lastAckedPacketSentTime,
+		lastAckedPacketAckTime:          sampler.lastAckedPacketAckTime,
+		// Freeze the sampler's counters and app-limited flag at send time.
+		sendTimeState: *newSendTimeState(
+			sampler.isAppLimited,
+			sampler.totalBytesSent,
+			sampler.totalBytesAcked,
+			sampler.totalBytesLost,
+			bytesInFlight,
+		),
+	}
+}
+
+// BandwidthSampler keeps track of sent and acknowledged packets and outputs a
+// bandwidth sample for every packet acknowledged. The samples are taken for
+// individual packets, and are not filtered; the consumer has to filter the
+// bandwidth samples itself. In certain cases, the sampler will locally severely
+// underestimate the bandwidth, hence a maximum filter with a size of at least
+// one RTT is recommended.
+//
+// This class bases its samples on the slope of two curves: the number of bytes
+// sent over time, and the number of bytes acknowledged as received over time.
+// It produces a sample of both slopes for every packet that gets acknowledged,
+// based on a slope between two points on each of the corresponding curves. Note
+// that due to packet loss, the number of bytes on each curve might get
+// further and further away from each other, meaning that it is not feasible to
+// compare byte values coming from different curves with each other.
+//
+// The obvious points for measuring a slope sample are the ones corresponding to
+// the packet that was just acknowledged. Let us denote them as S_1 (point at
+// which the current packet was sent) and A_1 (point at which the current packet
+// was acknowledged). However, taking a slope requires two points on each line,
+// so estimating bandwidth requires picking a packet in the past with respect to
+// which the slope is measured.
+//
+// For that purpose, BandwidthSampler always keeps track of the most recently
+// acknowledged packet, and records it together with every outgoing packet.
+// When a packet gets acknowledged (A_1), it has not only information about when
+// it itself was sent (S_1), but also the information about the latest
+// acknowledged packet right before it was sent (S_0 and A_0).
+//
+// Based on that data, send and ack rate are estimated as:
+//
+//	send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0))
+//	ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0))
+//
+// Here, the ack rate is intuitively the rate we want to treat as bandwidth.
+// However, in certain cases (e.g. ack compression) the ack rate at a point may
+// end up higher than the rate at which the data was originally sent, which is
+// not indicative of the real bandwidth. Hence, we use the send rate as an upper
+// bound, and the sample value is
+//
+//	rate_sample = min(send_rate, ack_rate)
+//
+// An important edge case handled by the sampler is tracking the app-limited
+// samples. There are multiple meanings of "app-limited" used interchangeably,
+// hence it is important to understand and to be able to distinguish between
+// them.
+//
+// Meaning 1: connection state. The connection is said to be app-limited when
+// there is no outstanding data to send. This means that certain bandwidth
+// samples in the future would not be an accurate indication of the link
+// capacity, and it is important to inform the consumer about that. Whenever
+// the connection becomes app-limited, the sampler is notified via the
+// OnAppLimited() method.
+//
+// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth
+// sampler becomes notified about the connection being app-limited, it enters
+// the app-limited phase. In that phase, all *sent* packets are marked as
+// app-limited. Note that the connection itself does not have to be
+// app-limited during the app-limited phase, and in fact it will not be
+// (otherwise how would it send packets?). The boolean flag below indicates
+// whether the sampler is in that phase.
+//
+// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is
+// sent during the app-limited phase, the resulting sample related to the
+// packet will be marked as app-limited.
+//
+// With the terminology issue out of the way, let us consider the question of
+// what kind of situation it addresses.
+//
+// Consider a scenario where we first send packets 1 to 20 at a regular
+// bandwidth, and then immediately run out of data. After a few seconds, we send
+// packets 21 to 60, and only receive ack for 21 between sending packets 40 and
+// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0
+// we use to compute the slope is going to be packet 20, a few seconds apart
+// from the current packet, hence the resulting estimate would be extremely low
+// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21,
+// meaning that the bandwidth sample would exclude the quiescence.
+//
+// Based on the analysis of that scenario, we implement the following rule: once
+// OnAppLimited() is called, all sent packets will produce app-limited samples
+// up until an ack for a packet that was sent after OnAppLimited() was called.
+// Note that while the scenario above is not the only scenario when the
+// connection is app-limited, the approach works in other cases too.
+
+// congestionEventSample aggregates the per-packet samples of one congestion
+// event (a batch of acked and/or lost packets).
+type congestionEventSample struct {
+	// The maximum bandwidth sample from all acked packets.
+	// QuicBandwidth::Zero() if no samples are available.
+	sampleMaxBandwidth Bandwidth
+	// Whether |sample_max_bandwidth| is from an app-limited sample.
+	sampleIsAppLimited bool
+	// The minimum rtt sample from all acked packets.
+	// QuicTime::Delta::Infinite() (infRTT in this port) if no samples are
+	// available.
+	sampleRtt time.Duration
+	// For each packet p in acked packets, this is the max value of INFLIGHT(p),
+	// where INFLIGHT(p) is the number of bytes acked while p is inflight.
+	sampleMaxInflight congestion.ByteCount
+	// The send state of the largest packet in acked_packets, unless it is
+	// empty. If acked_packets is empty, it's the send state of the largest
+	// packet in lost_packets.
+	lastPacketSendState sendTimeState
+	// The number of extra bytes acked from this ack event, compared to what is
+	// expected from the flow's bandwidth. Larger value means more ack
+	// aggregation.
+	extraAcked congestion.ByteCount
+}
+
+// newCongestionEventSample returns an empty sample. sampleRtt starts at
+// infRTT; every other field keeps its zero value.
+func newCongestionEventSample() *congestionEventSample {
+	return &congestionEventSample{
+		sampleRtt: infRTT,
+	}
+}
+
+// bandwidthSampler holds the byte counters and per-packet send-time records
+// used to produce bandwidth samples.
+type bandwidthSampler struct {
+	// The total number of congestion controlled bytes sent during the connection.
+	totalBytesSent congestion.ByteCount
+
+	// The total number of congestion controlled bytes which were acknowledged.
+	totalBytesAcked congestion.ByteCount
+
+	// The total number of congestion controlled bytes which were lost.
+	totalBytesLost congestion.ByteCount
+
+	// The total number of congestion controlled bytes which have been neutered.
+	totalBytesNeutered congestion.ByteCount
+
+	// The value of |total_bytes_sent_| at the time the last acknowledged packet
+	// was sent. Valid only when |last_acked_packet_sent_time_| is valid.
+	totalBytesSentAtLastAckedPacket congestion.ByteCount
+
+	// The time at which the last acknowledged packet was sent. Set to
+	// QuicTime::Zero() if no valid timestamp is available.
+	lastAckedPacketSentTime time.Time
+
+	// The time at which the most recent packet was acknowledged.
+	lastAckedPacketAckTime time.Time
+
+	// The most recently sent packet.
+	lastSentPacket congestion.PacketNumber
+
+	// The most recently acked packet.
+	lastAckedPacket congestion.PacketNumber
+
+	// Indicates whether the bandwidth sampler is currently in an app-limited
+	// phase.
+	isAppLimited bool
+
+	// The packet that will be acknowledged after this one will cause the sampler
+	// to exit the app-limited phase.
+	endOfAppLimitedPhase congestion.PacketNumber
+
+	// Record of the connection state at the point where each packet in flight was
+	// sent, indexed by the packet number.
+	connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket]
+
+	// The two most recent ack points, and the candidate A_0 points derived from
+	// them. Both are only maintained when overestimateAvoidance is on.
+	recentAckPoints recentAckPoints
+	a0Candidates    RingBuffer[ackPoint]
+
+	// Maximum number of tracked packets.
+	// NOTE(review): not referenced anywhere in the visible code - confirm it is
+	// used, and that ByteCount is the intended type.
+	maxTrackedPackets congestion.ByteCount
+
+	maxAckHeightTracker *maxAckHeightTracker
+	// Value of totalBytesAcked at the end of the previous ack event.
+	totalBytesAckedAfterLastAckEvent congestion.ByteCount
+
+	// True if connection option 'BSAO' is set.
+	overestimateAvoidance bool
+
+	// True if connection option 'BBRB' is set.
+	limitMaxAckHeightTrackerBySendRate bool
+}
+
+// newBandwidthSampler creates a sampler whose ack-height filter spans
+// maxAckHeightTrackerWindowLength round trips. All packet-number fields start
+// at invalidPacketNumber.
+func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler {
+	b := &bandwidthSampler{
+		maxAckHeightTracker:  newMaxAckHeightTracker(maxAckHeightTrackerWindowLength),
+		connectionStateMap:   newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize),
+		lastSentPacket:       invalidPacketNumber,
+		lastAckedPacket:      invalidPacketNumber,
+		endOfAppLimitedPhase: invalidPacketNumber,
+	}
+
+	b.a0Candidates.Init(defaultCandidatesBufferSize)
+
+	return b
+}
+
+// MaxAckHeight returns the tracker's current best extra-acked value.
+func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount {
+	return b.maxAckHeightTracker.Get()
+}
+
+// NumAckAggregationEpochs returns how many aggregation epochs the tracker has
+// started.
+func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 {
+	return b.maxAckHeightTracker.NumAckAggregationEpochs()
+}
+
+// SetMaxAckHeightTrackerWindowLength forwards the new window length to the
+// tracker's filter.
+func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) {
+	b.maxAckHeightTracker.SetFilterWindowLength(length)
+}
+
+// ResetMaxAckHeightTracker re-seeds the tracker with newHeight at round
+// newTime.
+func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) {
+	b.maxAckHeightTracker.Reset(newHeight, newTime)
+}
+
+// SetStartNewAggregationEpochAfterFullRound forwards the flag to the tracker.
+func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) {
+	b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value)
+}
+
+// SetLimitMaxAckHeightTrackerBySendRate controls whether OnCongestionEvent
+// also raises the bandwidth passed to the tracker to the max send rate seen.
+func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) {
+	b.limitMaxAckHeightTrackerBySendRate = value
+}
+
+// SetReduceExtraAckedOnBandwidthIncrease forwards the flag to the tracker.
+func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) {
+	b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value)
+}
+
+// EnableOverestimateAvoidance turns on A_0 candidate tracking and raises the
+// tracker's bandwidth threshold to 2.0. Calling it again is a no-op.
+func (b *bandwidthSampler) EnableOverestimateAvoidance() {
+	if b.overestimateAvoidance {
+		return
+	}
+
+	b.overestimateAvoidance = true
+	b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0)
+}
+
+// IsOverestimateAvoidanceEnabled reports whether overestimate avoidance is on.
+func (b *bandwidthSampler) IsOverestimateAvoidanceEnabled() bool {
+	return b.overestimateAvoidance
+}
+
+// OnPacketSent records a sent packet. lastSentPacket is always updated;
+// non-retransmittable packets are otherwise ignored. bytesInFlight is the
+// amount in flight before this packet (the stored snapshot adds bytes to it).
+func (b *bandwidthSampler) OnPacketSent(
+	sentTime time.Time,
+	packetNumber congestion.PacketNumber,
+	bytes congestion.ByteCount,
+	bytesInFlight congestion.ByteCount,
+	isRetransmittable bool,
+) {
+	b.lastSentPacket = packetNumber
+
+	if !isRetransmittable {
+		return
+	}
+
+	b.totalBytesSent += bytes
+
+	// If there are no packets in flight, the time at which the new transmission
+	// opens can be treated as the A_0 point for the purpose of bandwidth
+	// sampling. This underestimates bandwidth to some extent, and produces some
+	// artificially low samples for most packets in flight, but it provides with
+	// samples at important points where we would not have them otherwise, most
+	// importantly at the beginning of the connection.
+	if bytesInFlight == 0 {
+		b.lastAckedPacketAckTime = sentTime
+		if b.overestimateAvoidance {
+			b.recentAckPoints.Clear()
+			b.recentAckPoints.Update(sentTime, b.totalBytesAcked)
+			b.a0Candidates.Clear()
+			b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint())
+		}
+		b.totalBytesSentAtLastAckedPacket = b.totalBytesSent
+
+		// In this situation ack compression is not a concern, set send rate to
+		// effectively infinite.
+		b.lastAckedPacketSentTime = sentTime
+	}
+
+	// Snapshot the sampler state; bytes in flight now includes this packet.
+	b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket(
+		sentTime,
+		bytes,
+		bytesInFlight+bytes,
+		b,
+	))
+}
+
+// OnCongestionEvent processes one batch of lost and acked packets and returns
+// the aggregated sample. Lost packets are handled first. When no packet was
+// acked, only lastPacketSendState is filled in and the ack-height tracker is
+// not updated.
+func (b *bandwidthSampler) OnCongestionEvent(
+	ackTime time.Time,
+	ackedPackets []congestion.AckedPacketInfo,
+	lostPackets []congestion.LostPacketInfo,
+	maxBandwidth Bandwidth,
+	estBandwidthUpperBound Bandwidth,
+	roundTripCount roundTripCount,
+) congestionEventSample {
+	eventSample := newCongestionEventSample()
+
+	var lastLostPacketSendState sendTimeState
+
+	for _, p := range lostPackets {
+		sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost)
+		if sendState.isValid {
+			lastLostPacketSendState = sendState
+		}
+	}
+
+	if len(ackedPackets) == 0 {
+		// Only populate send state for a loss-only event.
+		eventSample.lastPacketSendState = lastLostPacketSendState
+		return *eventSample
+	}
+
+	var lastAckedPacketSendState sendTimeState
+	var maxSendRate Bandwidth
+
+	for _, p := range ackedPackets {
+		sample := b.onPacketAcknowledged(ackTime, p.PacketNumber)
+		// Packets without a usable sample are skipped entirely.
+		if !sample.stateAtSend.isValid {
+			continue
+		}
+
+		lastAckedPacketSendState = sample.stateAtSend
+
+		if sample.rtt != 0 {
+			eventSample.sampleRtt = min(eventSample.sampleRtt, sample.rtt)
+		}
+		if sample.bandwidth > eventSample.sampleMaxBandwidth {
+			eventSample.sampleMaxBandwidth = sample.bandwidth
+			eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited
+		}
+		// infBandwidth means "no send rate for this sample"; keep it out of the max.
+		if sample.sendRate != infBandwidth {
+			maxSendRate = max(maxSendRate, sample.sendRate)
+		}
+		// Bytes acked between this packet's send time and now.
+		inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked
+		if inflightSample > eventSample.sampleMaxInflight {
+			eventSample.sampleMaxInflight = inflightSample
+		}
+	}
+
+	if !lastLostPacketSendState.isValid {
+		eventSample.lastPacketSendState = lastAckedPacketSendState
+	} else if !lastAckedPacketSendState.isValid {
+		eventSample.lastPacketSendState = lastLostPacketSendState
+	} else {
+		// If two packets are inflight and an alarm is armed to lose a packet and it
+		// wakes up late, then the first of two in flight packets could have been
+		// acknowledged before the wakeup, which re-evaluates loss detection, and
+		// could declare the later of the two lost.
+		// Both states are valid here, so both slices are non-empty.
+		if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber {
+			eventSample.lastPacketSendState = lastLostPacketSendState
+		} else {
+			eventSample.lastPacketSendState = lastAckedPacketSendState
+		}
+	}
+
+	isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth
+	maxBandwidth = max(maxBandwidth, eventSample.sampleMaxBandwidth)
+	if b.limitMaxAckHeightTrackerBySendRate {
+		maxBandwidth = max(maxBandwidth, maxSendRate)
+	}
+
+	eventSample.extraAcked = b.onAckEventEnd(min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount)
+
+	return *eventSample
+}
+
+// OnPacketLost adds bytesLost to the loss counter and returns the packet's
+// send-time state. The returned state has isValid set only when the packet is
+// still tracked.
+func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) {
+	b.totalBytesLost += bytesLost
+	if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil {
+		sentPacketToSendTimeState(sentPacketPointer, &s)
+	}
+	return s
+}
+
+// OnPacketNeutered stops tracking packetNumber; the callback passed to Remove
+// adds the packet's size to totalBytesNeutered.
+func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) {
+	b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) {
+		b.totalBytesNeutered += sentPacket.size
+	})
+}
+
+// OnAppLimited enters the app-limited phase and marks the last sent packet as
+// its end.
+func (b *bandwidthSampler) OnAppLimited() {
+	b.isAppLimited = true
+	b.endOfAppLimitedPhase = b.lastSentPacket
+}
+
+func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) {
+	// A packet can become obsolete when it is removed from QuicUnackedPacketMap's
+	// view of inflight before it is acked or marked as lost. For example, when
+	// QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet,
+	// the packet is removed from QuicUnackedPacketMap's inflight, but is not
+	// marked as acked or lost in the BandwidthSampler.
+ b.connectionStateMap.RemoveUpTo(leastUnacked) +} + +func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount { + return b.totalBytesSent +} + +func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount { + return b.totalBytesLost +} + +func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount { + return b.totalBytesAcked +} + +func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount { + return b.totalBytesNeutered +} + +func (b *bandwidthSampler) IsAppLimited() bool { + return b.isAppLimited +} + +func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber { + return b.endOfAppLimitedPhase +} + +func (b *bandwidthSampler) max_ack_height() congestion.ByteCount { + return b.maxAckHeightTracker.Get() +} + +func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool { + if b.a0Candidates.Empty() { + return false + } + + if b.a0Candidates.Len() == 1 { + *a0 = *b.a0Candidates.Front() + return true + } + + for i := 1; i < b.a0Candidates.Len(); i++ { + if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked { + *a0 = *b.a0Candidates.Offset(i - 1) + if i > 1 { + for j := 0; j < i-1; j++ { + b.a0Candidates.PopFront() + } + } + return true + } + } + + *a0 = *b.a0Candidates.Back() + for k := 0; k < b.a0Candidates.Len()-1; k++ { + b.a0Candidates.PopFront() + } + return true +} + +func (b *bandwidthSampler) onPacketAcknowledged(ackTime time.Time, packetNumber congestion.PacketNumber) bandwidthSample { + sample := newBandwidthSample() + b.lastAckedPacket = packetNumber + sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber) + if sentPacketPointer == nil { + return *sample + } + + // OnPacketAcknowledgedInner + b.totalBytesAcked += sentPacketPointer.size + b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent + b.lastAckedPacketSentTime = sentPacketPointer.sentTime + b.lastAckedPacketAckTime = ackTime + if b.overestimateAvoidance { + 
b.recentAckPoints.Update(ackTime, b.totalBytesAcked) + } + + if b.isAppLimited { + // Exit app-limited phase in two cases: + // (1) end_of_app_limited_phase_ is not initialized, i.e., so far all + // packets are sent while there are buffered packets or pending data. + // (2) The current acked packet is after the sent packet marked as the end + // of the app limit phase. + if b.endOfAppLimitedPhase == invalidPacketNumber || + packetNumber > b.endOfAppLimitedPhase { + b.isAppLimited = false + } + } + + // There might have been no packets acknowledged at the moment when the + // current packet was sent. In that case, there is no bandwidth sample to + // make. + if sentPacketPointer.lastAckedPacketSentTime.IsZero() { + return *sample + } + + // Infinite rate indicates that the sampler is supposed to discard the + // current send rate sample and use only the ack rate. + sendRate := infBandwidth + if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) { + sendRate = BandwidthFromDelta( + sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket, + sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime)) + } + + var a0 ackPoint + if b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0) { + } else { + a0.ackTime = sentPacketPointer.lastAckedPacketAckTime + a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked + } + + // During the slope calculation, ensure that ack time of the current packet is + // always larger than the time of the previous packet, otherwise division by + // zero or integer underflow can occur. + if ackTime.Sub(a0.ackTime) <= 0 { + return *sample + } + + ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime)) + + sample.bandwidth = min(sendRate, ackRate) + // Note: this sample does not account for delayed acknowledgement time. 
This + // means that the RTT measurements here can be artificially high, especially + // on low bandwidth connections. + sample.rtt = ackTime.Sub(sentPacketPointer.sentTime) + sample.sendRate = sendRate + sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend) + + return *sample +} + +func (b *bandwidthSampler) onAckEventEnd( + bandwidthEstimate Bandwidth, + isNewMaxBandwidth bool, + roundTripCount roundTripCount, +) congestion.ByteCount { + newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent + if newlyAckedBytes == 0 { + return 0 + } + b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked + extraAcked := b.maxAckHeightTracker.Update( + bandwidthEstimate, + isNewMaxBandwidth, + roundTripCount, + b.lastSentPacket, + b.lastAckedPacket, + b.lastAckedPacketAckTime, + newlyAckedBytes) + // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack + // aggregation epoch, save LessRecentPoint, which is the last ack point of the + // previous epoch, as a A0 candidate. 
+ if b.overestimateAvoidance && extraAcked == 0 { + b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint()) + } + return extraAcked +} + +func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) { + *sendTimeState = sentPacket.sendTimeState + sendTimeState.isValid = true +} + +// BytesFromBandwidthAndTimeDelta calculates the bytes +// from a bandwidth(bits per second) and a time delta +func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount { + return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) / + (congestion.ByteCount(time.Second) * 8) +} + +func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration { + return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth) +} diff --git a/protocol/hysteria2/internal/congestion/bbr/bbr_sender.go b/protocol/hysteria2/internal/congestion/bbr/bbr_sender.go new file mode 100644 index 0000000..63f5528 --- /dev/null +++ b/protocol/hysteria2/internal/congestion/bbr/bbr_sender.go @@ -0,0 +1,984 @@ +package bbr + +import ( + "fmt" + "math/rand" + "net" + "os" + "strconv" + "time" + + "github.com/daeuniverse/quic-go/congestion" + + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/common" +) + +// BbrSender implements BBR congestion control algorithm. BBR aims to estimate +// the current available Bottleneck Bandwidth and RTT (hence the name), and +// regulates the pacing rate and the size of the congestion window based on +// those signals. +// +// BBR relies on pacing in order to function properly. Do not use BBR when +// pacing is disabled. +// + +const ( + minBps = 65536 // 64 kbps + + invalidPacketNumber = -1 + initialCongestionWindowPackets = 32 + + // Constants based on TCP defaults. + // The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. + // Does not inflate the pacing rate. 
+	defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSizeIPv4)
+
+	// The gain used for the STARTUP, equal to 2/ln(2).
+	defaultHighGain = 2.885
+	// The newly derived gain for STARTUP, equal to 4 * ln(2)
+	derivedHighGain = 2.773
+	// The newly derived CWND gain for STARTUP, 2.
+	derivedHighCWNDGain = 2.0
+
+	// Name of the debug environment variable.
+	// NOTE(review): presumably read elsewhere in the file to enable debug
+	// output - confirm against the code that uses it.
+	debugEnv = "HYSTERIA_BBR_DEBUG"
+)
+
+// The cycle of gains used during the PROBE_BW stage.
+var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}
+
+const (
+	// The length of the gain cycle (8, the number of entries in pacingGain).
+	gainCycleLength = len(pacingGain)
+	// The size of the bandwidth filter window, in round-trips.
+	bandwidthWindowSize = gainCycleLength + 2
+
+	// The time after which the current min_rtt value expires.
+	minRttExpiry = 10 * time.Second
+	// The minimum time the connection can spend in PROBE_RTT mode.
+	probeRttTime = 200 * time.Millisecond
+	// If the bandwidth does not increase by the factor of |kStartupGrowthTarget|
+	// within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection
+	// will exit the STARTUP mode.
+	startupGrowthTarget                         = 1.25
+	roundTripsWithoutGrowthBeforeExitingStartup = int64(3)
+
+	// Flag.
+	defaultStartupFullLossCount  = 8
+	quicBbr2DefaultLossThreshold = 0.02
+	maxBbrBurstPackets           = 10
+)
+
+// bbrMode identifies the phase the BBR state machine is in.
+type bbrMode int
+
+// The constants below are declared without an explicit type, so they are
+// untyped integer constants rather than bbrMode values.
+const (
+	// Startup phase of the connection.
+	bbrModeStartup = iota
+	// After achieving the highest possible bandwidth during the startup, lower
+	// the pacing rate in order to drain the queue.
+	bbrModeDrain
+	// Cruising mode.
+	bbrModeProbeBw
+	// Temporarily slow down sending in order to empty the buffer and measure
+	// the real minimum RTT.
+	bbrModeProbeRtt
+)
+
+// Indicates how the congestion control limits the amount of bytes in flight.
+type bbrRecoveryState int
+
+// As with bbrMode, these constants are untyped integers.
+const (
+	// Do not limit.
+	bbrRecoveryStateNotInRecovery = iota
+	// Allow an extra outstanding byte for each byte acknowledged.
+ bbrRecoveryStateConservation + // Allow two extra outstanding bytes for each byte acknowledged (slow + // start). + bbrRecoveryStateGrowth +) + +type bbrSender struct { + rttStats congestion.RTTStatsProvider + clock Clock + pacer *common.Pacer + + mode bbrMode + + // Bandwidth sampler provides BBR with the bandwidth measurements at + // individual points. + sampler *bandwidthSampler + + // The number of the round trips that have occurred during the connection. + roundTripCount roundTripCount + + // The packet number of the most recently sent packet. + lastSentPacket congestion.PacketNumber + // Acknowledgement of any packet after |current_round_trip_end_| will cause + // the round trip counter to advance. + currentRoundTripEnd congestion.PacketNumber + + // Number of congestion events with some losses, in the current round. + numLossEventsInRound uint64 + + // Number of total bytes lost in the current round. + bytesLostInRound congestion.ByteCount + + // The filter that tracks the maximum bandwidth over the multiple recent + // round-trips. + maxBandwidth *WindowedFilter[Bandwidth, roundTripCount] + + // Minimum RTT estimate. Automatically expires within 10 seconds (and + // triggers PROBE_RTT mode) if no new value is sampled during that period. + minRtt time.Duration + // The time at which the current value of |min_rtt_| was assigned. + minRttTimestamp time.Time + + // The maximum allowed number of bytes in flight. + congestionWindow congestion.ByteCount + + // The initial value of the |congestion_window_|. + initialCongestionWindow congestion.ByteCount + + // The largest value the |congestion_window_| can achieve. + maxCongestionWindow congestion.ByteCount + + // The smallest value the |congestion_window_| can achieve. + minCongestionWindow congestion.ByteCount + + // The pacing gain applied during the STARTUP phase. + highGain float64 + + // The CWND gain applied during the STARTUP phase. 
+ highCwndGain float64 + + // The pacing gain applied during the DRAIN phase. + drainGain float64 + + // The current pacing rate of the connection. + pacingRate Bandwidth + + // The gain currently applied to the pacing rate. + pacingGain float64 + // The gain currently applied to the congestion window. + congestionWindowGain float64 + + // The gain used for the congestion window during PROBE_BW. Latched from + // quic_bbr_cwnd_gain flag. + congestionWindowGainConstant float64 + // The number of RTTs to stay in STARTUP mode. Defaults to 3. + numStartupRtts int64 + + // Number of round-trips in PROBE_BW mode, used for determining the current + // pacing gain cycle. + cycleCurrentOffset int + // The time at which the last pacing gain cycle was started. + lastCycleStart time.Time + + // Indicates whether the connection has reached the full bandwidth mode. + isAtFullBandwidth bool + // Number of rounds during which there was no significant bandwidth increase. + roundsWithoutBandwidthGain int64 + // The bandwidth compared to which the increase is measured. + bandwidthAtLastRound Bandwidth + + // Set to true upon exiting quiescence. + exitingQuiescence bool + + // Time at which PROBE_RTT has to be exited. Setting it to zero indicates + // that the time is yet unknown as the number of packets in flight has not + // reached the required value. + exitProbeRttAt time.Time + // Indicates whether a round-trip has passed since PROBE_RTT became active. + probeRttRoundPassed bool + + // Indicates whether the most recent bandwidth sample was marked as + // app-limited. + lastSampleIsAppLimited bool + // Indicates whether any non app-limited samples have been recorded. + hasNoAppLimitedSample bool + + // Current state of recovery. + recoveryState bbrRecoveryState + // Receiving acknowledgement of a packet after |end_recovery_at_| will cause + // BBR to exit the recovery mode. A value above zero indicates at least one + // loss has been detected, so it must not be set back to zero. 
+ endRecoveryAt congestion.PacketNumber + // A window used to limit the number of bytes in flight during loss recovery. + recoveryWindow congestion.ByteCount + // If true, consider all samples in recovery app-limited. + isAppLimitedRecovery bool // not used + + // When true, pace at 1.5x and disable packet conservation in STARTUP. + slowerStartup bool // not used + // When true, disables packet conservation in STARTUP. + rateBasedStartup bool // not used + + // When true, add the most recent ack aggregation measurement during STARTUP. + enableAckAggregationDuringStartup bool + // When true, expire the windowed ack aggregation values in STARTUP when + // bandwidth increases more than 25%. + expireAckAggregationInStartup bool + + // If true, will not exit low gain mode until bytes_in_flight drops below BDP + // or it's time for high gain mode. + drainToTarget bool + + // If true, slow down pacing rate in STARTUP when overshooting is detected. + detectOvershooting bool + // Bytes lost while detect_overshooting_ is true. + bytesLostWhileDetectingOvershooting congestion.ByteCount + // Slow down pacing rate if + // bytes_lost_while_detecting_overshooting_ * + // bytes_lost_multiplier_while_detecting_overshooting_ > IW. + bytesLostMultiplierWhileDetectingOvershooting uint8 + // When overshooting is detected, do not drop pacing_rate_ below this value / + // min_rtt. + cwndToCalculateMinPacingRate congestion.ByteCount + + // Max congestion window when adjusting network parameters. + maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used + + // Params. + maxDatagramSize congestion.ByteCount + // Recorded on packet sent. 
equivalent |unacked_packets_->bytes_in_flight()| + bytesInFlight congestion.ByteCount + + debug bool +} + +var _ congestion.CongestionControl = &bbrSender{} + +func NewBbrSender( + clock Clock, + initialMaxDatagramSize congestion.ByteCount, +) *bbrSender { + return newBbrSender( + clock, + initialMaxDatagramSize, + initialCongestionWindowPackets*initialMaxDatagramSize, + congestion.MaxCongestionWindowPackets*initialMaxDatagramSize, + ) +} + +func newBbrSender( + clock Clock, + initialMaxDatagramSize, + initialCongestionWindow, + initialMaxCongestionWindow congestion.ByteCount, +) *bbrSender { + debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) + b := &bbrSender{ + clock: clock, + mode: bbrModeStartup, + sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)), + lastSentPacket: invalidPacketNumber, + currentRoundTripEnd: invalidPacketNumber, + maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]), + congestionWindow: initialCongestionWindow, + initialCongestionWindow: initialCongestionWindow, + maxCongestionWindow: initialMaxCongestionWindow, + minCongestionWindow: defaultMinimumCongestionWindow, + highGain: defaultHighGain, + highCwndGain: defaultHighGain, + drainGain: 1.0 / defaultHighGain, + pacingGain: 1.0, + congestionWindowGain: 1.0, + congestionWindowGainConstant: 2.0, + numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, + recoveryState: bbrRecoveryStateNotInRecovery, + endRecoveryAt: invalidPacketNumber, + recoveryWindow: initialMaxCongestionWindow, + bytesLostMultiplierWhileDetectingOvershooting: 2, + cwndToCalculateMinPacingRate: initialCongestionWindow, + maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, + maxDatagramSize: initialMaxDatagramSize, + debug: debug, + } + b.pacer = common.NewPacer(b.bandwidthForPacer) + + /* + if b.tracer != nil { + b.lastState = logging.CongestionStateStartup + b.tracer.UpdatedCongestionState(logging.CongestionStateStartup) + } + */ + + 
b.enterStartupMode(b.clock.Now()) + b.setHighCwndGain(derivedHighCWNDGain) + + return b +} + +func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { + b.rttStats = provider +} + +// TimeUntilSend implements the SendAlgorithm interface. +func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return b.pacer.TimeUntilSend() +} + +// HasPacingBudget implements the SendAlgorithm interface. +func (b *bbrSender) HasPacingBudget(now time.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +// OnPacketSent implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketSent( + sentTime time.Time, + bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, + bytes congestion.ByteCount, + isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) + + b.lastSentPacket = packetNumber + b.bytesInFlight = bytesInFlight + + if bytesInFlight == 0 { + b.exitingQuiescence = true + } + + b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable) +} + +// CanSend implements the SendAlgorithm interface. +func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +// MaybeExitSlowStart implements the SendAlgorithm interface. +func (b *bbrSender) MaybeExitSlowStart() { + // Do nothing +} + +// OnPacketAcked implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime time.Time) { + // Do nothing. +} + +// OnPacketLost implements the SendAlgorithm interface. +func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { + // Do nothing. +} + +// OnRetransmissionTimeout implements the SendAlgorithm interface. +func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { + // Do nothing. 
+}
+
+// SetMaxDatagramSize implements the SendAlgorithm interface.
+// It panics if s is smaller than the current maximum datagram size, records
+// the new size and forwards it to the pacer.
+func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) {
+	if s < b.maxDatagramSize {
+		panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s))
+	}
+	cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow
+	b.maxDatagramSize = s
+	if cwndIsMinCwnd {
+		b.congestionWindow = b.minCongestionWindow
+	}
+	b.pacer.SetMaxDatagramSize(s)
+}
+
+// InSlowStart implements the SendAlgorithmWithDebugInfos interface.
+// BBR reports its STARTUP mode as slow start.
+func (b *bbrSender) InSlowStart() bool {
+	return b.mode == bbrModeStartup
+}
+
+// InRecovery implements the SendAlgorithmWithDebugInfos interface.
+func (b *bbrSender) InRecovery() bool {
+	return b.recoveryState != bbrRecoveryStateNotInRecovery
+}
+
+// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface.
+// In PROBE_RTT the PROBE_RTT window is returned; while in recovery the result
+// is additionally capped by the recovery window.
+func (b *bbrSender) GetCongestionWindow() congestion.ByteCount {
+	if b.mode == bbrModeProbeRtt {
+		return b.probeRttCongestionWindow()
+	}
+
+	if b.InRecovery() {
+		return min(b.congestionWindow, b.recoveryWindow)
+	}
+
+	return b.congestionWindow
+}
+
+// OnCongestionEvent is intentionally a no-op; all processing happens in
+// OnCongestionEventEx.
+func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) {
+	// Do nothing.
+}
+
+// OnCongestionEventEx feeds one batch of acknowledged and lost packets into
+// the BBR model: it updates the bytes in flight, the round-trip counter, the
+// recovery state, the bandwidth and min-RTT estimates, runs the mode state
+// machine, and finally recalculates pacing rate, congestion window and
+// recovery window.
+//
+// priorInFlight is the number of bytes in flight before this event and
+// eventTime is the time of the event. The last element of ackedPackets (or of
+// lostPackets when nothing was acked) is used as the most recent packet.
+func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
+	totalBytesAckedBefore := b.sampler.TotalBytesAcked()
+	totalBytesLostBefore := b.sampler.TotalBytesLost()
+	hasLosses := len(lostPackets) != 0
+
+	var isRoundStart, minRttExpired bool
+
+	b.maybeAppLimited(priorInFlight)
+
+	// Update bytesInFlight
+	b.bytesInFlight = priorInFlight
+	for _, p := range ackedPackets {
+		b.bytesInFlight -= p.BytesAcked
+	}
+	for _, p := range lostPackets {
+		b.bytesInFlight -= p.BytesLost
+	}
+
+	if len(ackedPackets) != 0 {
+		lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber
+		isRoundStart = b.updateRoundTripCounter(lastAckedPacket)
+		b.updateRecoveryState(lastAckedPacket, hasLosses, isRoundStart)
+	}
+
+	sample := b.sampler.OnCongestionEvent(eventTime,
+		ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount)
+	if sample.lastPacketSendState.isValid {
+		b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited
+		b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited
+	}
+	// Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all
+	// packets in |acked_packets| did not generate valid samples. (e.g. ack of
+	// ack-only packets). In both cases, sampler_.total_bytes_acked() will not
+	// change.
+	if totalBytesAckedBefore != b.sampler.TotalBytesAcked() {
+		if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() {
+			b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount)
+		}
+	}
+
+	if sample.sampleRtt != infRTT {
+		minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt)
+	}
+	bytesLost := b.sampler.TotalBytesLost() - totalBytesLostBefore
+	excessAcked := sample.extraAcked
+
+	// The send state of the largest packet in acked_packets, unless it is
+	// empty. If acked_packets is empty, it's the send state of the largest
+	// packet in lost_packets.
+	lastPacketSendState := sample.lastPacketSendState
+
+	if hasLosses {
+		b.numLossEventsInRound++
+		b.bytesLostInRound += bytesLost
+	}
+
+	// Handle logic specific to PROBE_BW mode.
+	if b.mode == bbrModeProbeBw {
+		b.updateGainCyclePhase(eventTime, priorInFlight, hasLosses)
+	}
+
+	// Handle logic specific to STARTUP and DRAIN modes.
+	if isRoundStart && !b.isAtFullBandwidth {
+		b.checkIfFullBandwidthReached(&lastPacketSendState)
+	}
+
+	b.maybeExitStartupOrDrain(eventTime)
+
+	// Handle logic specific to PROBE_RTT.
+	b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired)
+
+	// Calculate number of packets acked and lost.
+	bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore
+
+	// After the model is updated, recalculate the pacing rate and congestion
+	// window.
+	b.calculatePacingRate(bytesLost)
+	b.calculateCongestionWindow(bytesAcked, excessAcked)
+	b.calculateRecoveryWindow(bytesAcked, bytesLost)
+
+	// Cleanup internal state.
+	// This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler.
+	// The "least unacked" should actually be FirstOutstanding, but since we are not passing
+	// that through OnCongestionEventEx, we will only do an estimate using acked/lost packets
+	// for now. Because of fast retransmission, they should differ by no more than 2 packets.
+	// (this is controlled by packetThreshold in quic-go's sentPacketHandler)
+	//
+	// When both slices are empty there is nothing to derive an estimate from,
+	// so the sampler is left untouched instead of indexing an empty slice.
+	if n := len(ackedPackets); n != 0 {
+		b.sampler.RemoveObsoletePackets(ackedPackets[n-1].PacketNumber - 2)
+	} else if n := len(lostPackets); n != 0 {
+		b.sampler.RemoveObsoletePackets(lostPackets[n-1].PacketNumber + 1)
+	}
+
+	if isRoundStart {
+		b.numLossEventsInRound = 0
+		b.bytesLostInRound = 0
+	}
+}
+
+// PacingRate returns the current pacing rate. Before any pacing rate has been
+// computed it falls back to highGain * (initial congestion window / min RTT).
+func (b *bbrSender) PacingRate() Bandwidth {
+	if b.pacingRate == 0 {
+		return Bandwidth(b.highGain * float64(
+			BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt())))
+	}
+
+	return b.pacingRate
+}
+
+// hasGoodBandwidthEstimateForResumption reports whether a non-app-limited
+// bandwidth sample has been recorded.
+func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool {
+	return b.hasNonAppLimitedSample()
+}
+
+// hasNonAppLimitedSample reports whether any sample that was not app-limited
+// has been seen. Despite its name, the hasNoAppLimitedSample field holds
+// exactly that: OnCongestionEventEx sets it once a valid sample is not
+// app-limited.
+func (b *bbrSender) hasNonAppLimitedSample() bool {
+	return b.hasNoAppLimitedSample
+}
+
+// Sets the pacing gain used in STARTUP. Must be greater than 1.
+func (b *bbrSender) setHighGain(highGain float64) {
+	b.highGain = highGain
+	if b.mode != bbrModeStartup {
+		return
+	}
+	// STARTUP is active, so the new gain takes effect immediately.
+	b.pacingGain = highGain
+}
+
+// Sets the CWND gain used in STARTUP. Must be greater than 1.
+// While in STARTUP the currently applied CWND gain is updated as well.
+func (b *bbrSender) setHighCwndGain(highCwndGain float64) {
+	b.highCwndGain = highCwndGain
+	if b.mode != bbrModeStartup {
+		return
+	}
+	b.congestionWindowGain = highCwndGain
+}
+
+// Sets the gain used in DRAIN. Must be less than 1.
+func (b *bbrSender) setDrainGain(drainGain float64) {
+	b.drainGain = drainGain
+}
+
+// Returns the best value currently held by the max-bandwidth filter.
+// Note that Bandwidth is in bits per second.
+func (b *bbrSender) bandwidthEstimate() Bandwidth {
+	return b.maxBandwidth.GetBest()
+}
+
+// Returns the byte rate handed to the pacer: the bandwidth estimate scaled by
+// the current congestion window gain, never lower than minBps.
+func (b *bbrSender) bandwidthForPacer() congestion.ByteCount {
+	scaled := float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond)
+	// We need to make sure that the bandwidth value for pacer is never zero,
+	// otherwise it will go into an edge case where HasPacingBudget = false
+	// but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck.
+	return max(congestion.ByteCount(scaled), minBps)
+}
+
+// Returns the current estimate of the RTT of the connection. Outside of the
+// edge cases, this is minimum RTT.
+func (b *bbrSender) getMinRtt() time.Duration {
+	if rtt := b.minRtt; rtt != 0 {
+		return rtt
+	}
+	// min_rtt could be available if the handshake packet gets neutered then
+	// gets acknowledged. This could only happen for QUIC crypto where we do not
+	// drop keys.
+	if rtt := b.rttStats.MinRTT(); rtt != 0 {
+		return rtt
+	}
+	// No measurement at all yet: fall back to a fixed default.
+	return 100 * time.Millisecond
+}
+
+// Computes the target congestion window using the specified gain.
+func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount {
+	bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate())
+	congestionWindow := congestion.ByteCount(gain * float64(bdp))
+
+	// BDP estimate will be zero if no bandwidth samples are available yet.
+	if congestionWindow == 0 {
+		congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow))
+	}
+
+	return max(congestionWindow, b.minCongestionWindow)
+}
+
+// The target congestion window during PROBE_RTT.
+func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount {
+	return b.minCongestionWindow
+}
+
+// maybeUpdateMinRtt replaces the min RTT estimate with sampleMinRtt when the
+// current estimate has expired, is unset, or is larger than the sample.
+// It returns whether the previous estimate had expired.
+func (b *bbrSender) maybeUpdateMinRtt(now time.Time, sampleMinRtt time.Duration) bool {
+	// Do not expire min_rtt if none was ever available.
+	minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry))
+	if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 {
+		b.minRtt = sampleMinRtt
+		b.minRttTimestamp = now
+	}
+
+	return minRttExpired
+}
+
+// Enters the STARTUP mode and applies the STARTUP pacing and CWND gains.
+// The now parameter is currently unused.
+func (b *bbrSender) enterStartupMode(now time.Time) {
+	b.mode = bbrModeStartup
+	// b.maybeTraceStateChange(logging.CongestionStateStartup)
+	b.pacingGain = b.highGain
+	b.congestionWindowGain = b.highCwndGain
+
+	if b.debug {
+		b.debugPrint("Phase: STARTUP")
+	}
+}
+
+// Enters the PROBE_BW mode, starting the gain cycle at a random phase.
+func (b *bbrSender) enterProbeBandwidthMode(now time.Time) {
+	b.mode = bbrModeProbeBw
+	// b.maybeTraceStateChange(logging.CongestionStateProbeBw)
+	b.congestionWindowGain = b.congestionWindowGainConstant
+
+	// Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is
+	// excluded because in that case increased gain and decreased gain would not
+	// follow each other.
+	// rand.Intn draws uniformly from [0, gainCycleLength-1); every value >= 1
+	// is then shifted up by one to skip offset 1.
+	offset := rand.Intn(gainCycleLength - 1)
+	if offset >= 1 {
+		offset++
+	}
+	b.cycleCurrentOffset = offset
+
+	b.lastCycleStart = now
+	b.pacingGain = pacingGain[b.cycleCurrentOffset]
+
+	if b.debug {
+		b.debugPrint("Phase: PROBE_BW")
+	}
+}
+
+// Updates the round-trip counter if a round-trip has passed. Returns true if
+// the counter has been advanced.
+func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool {
+	if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd {
+		b.roundTripCount++
+		b.currentRoundTripEnd = b.lastSentPacket
+		return true
+	}
+	return false
+}
+
+// Updates the current gain used in PROBE_BW mode.
+func (b *bbrSender) updateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) {
+	// In most cases, the cycle is advanced after an RTT passes.
+	shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt()))
+	// If the pacing gain is above 1.0, the connection is trying to probe the
+	// bandwidth by increasing the number of bytes in flight to at least
+	// pacing_gain * BDP. Make sure that it actually reaches the target, as long
+	// as there are no losses suggesting that the buffers are not able to hold
+	// that much.
+	if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) {
+		shouldAdvanceGainCycling = false
+	}
+
+	// If pacing gain is below 1.0, the connection is trying to drain the extra
+	// queue which could have been incurred by probing prior to it. If the number
+	// of bytes in flight falls down to the estimated BDP value earlier, conclude
+	// that the queue has been successfully drained and exit this cycle early.
+	if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) {
+		shouldAdvanceGainCycling = true
+	}
+
+	if !shouldAdvanceGainCycling {
+		return
+	}
+
+	b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength
+	b.lastCycleStart = now
+	// Stay in low gain mode until the target BDP is hit.
+	// Low gain mode will be exited immediately when the target BDP is achieved.
+	if b.drainToTarget && b.pacingGain < 1 &&
+		pacingGain[b.cycleCurrentOffset] == 1 &&
+		b.bytesInFlight > b.getTargetCongestionWindow(1) {
+		return
+	}
+	b.pacingGain = pacingGain[b.cycleCurrentOffset]
+}
+
+// Tracks for how many round-trips the bandwidth has not increased
+// significantly, and marks the connection as being at full bandwidth once
+// that count reaches numStartupRtts or the loss-based exit condition holds.
+// App-limited samples are ignored.
+func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) {
+	if b.lastSampleIsAppLimited {
+		return
+	}
+
+	target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget)
+	if b.bandwidthEstimate() >= target {
+		b.bandwidthAtLastRound = b.bandwidthEstimate()
+		b.roundsWithoutBandwidthGain = 0
+		if b.expireAckAggregationInStartup {
+			// Expire old excess delivery measurements now that bandwidth increased.
+			b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount)
+		}
+		return
+	}
+
+	b.roundsWithoutBandwidthGain++
+	if b.roundsWithoutBandwidthGain >= b.numStartupRtts ||
+		b.shouldExitStartupDueToLoss(lastPacketSendState) {
+		b.isAtFullBandwidth = true
+	}
+}
+
+// maybeAppLimited tells the sampler that the connection is app-limited when
+// bytesInFlight is below the target congestion window at gain 1.
+func (b *bbrSender) maybeAppLimited(bytesInFlight congestion.ByteCount) {
+	if bytesInFlight < b.getTargetCongestionWindow(1) {
+		b.sampler.OnAppLimited()
+	}
+}
+
+// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if
+// appropriate.
+func (b *bbrSender) maybeExitStartupOrDrain(now time.Time) {
+	// STARTUP -> DRAIN once full bandwidth has been reached.
+	if b.isAtFullBandwidth && b.mode == bbrModeStartup {
+		b.mode = bbrModeDrain
+		// b.maybeTraceStateChange(logging.CongestionStateDrain)
+		b.pacingGain, b.congestionWindowGain = b.drainGain, b.highCwndGain
+
+		if b.debug {
+			b.debugPrint("Phase: DRAIN")
+		}
+	}
+
+	// DRAIN -> PROBE_BW once the bytes in flight are down to the gain-1
+	// target window. This may happen in the same call as the transition above.
+	if b.mode != bbrModeDrain {
+		return
+	}
+	if b.bytesInFlight <= b.getTargetCongestionWindow(1) {
+		b.enterProbeBandwidthMode(now)
+	}
+}
+
+// Decides whether to enter or exit PROBE_RTT.
+// The exitingQuiescence flag is always cleared before returning.
+func (b *bbrSender) maybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) {
+	shouldEnter := minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt
+	if shouldEnter {
+		b.mode = bbrModeProbeRtt
+		// b.maybeTraceStateChange(logging.CongestionStateProbRtt)
+		b.pacingGain = 1.0
+		// Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight|
+		// is at the target small value.
+		b.exitProbeRttAt = time.Time{}
+
+		if b.debug {
+			b.debugPrint("BandwidthEstimate: %s, CongestionWindowGain: %.2f, PacingGain: %.2f, PacingRate: %s",
+				formatSpeed(b.bandwidthEstimate()), b.congestionWindowGain, b.pacingGain, formatSpeed(b.PacingRate()))
+			b.debugPrint("Phase: PROBE_RTT")
+		}
+	}
+
+	if b.mode == bbrModeProbeRtt {
+		b.sampler.OnAppLimited()
+		// b.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
+
+		switch {
+		case b.exitProbeRttAt.IsZero():
+			// If the window has reached the appropriate size, schedule exiting
+			// PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but
+			// we allow an extra packet since QUIC checks CWND before sending a
+			// packet.
+			inflightLimit := b.probeRttCongestionWindow() + congestion.MaxPacketBufferSize
+			if b.bytesInFlight < inflightLimit {
+				b.exitProbeRttAt = now.Add(probeRttTime)
+				b.probeRttRoundPassed = false
+			}
+		default:
+			if isRoundStart {
+				b.probeRttRoundPassed = true
+			}
+			// Leave PROBE_RTT only after both the exit time and a full round
+			// have passed.
+			if b.probeRttRoundPassed && now.Sub(b.exitProbeRttAt) >= 0 {
+				b.minRttTimestamp = now
+				if b.debug {
+					b.debugPrint("MinRTT: %s", b.getMinRtt())
+				}
+				if b.isAtFullBandwidth {
+					b.enterProbeBandwidthMode(now)
+				} else {
+					b.enterStartupMode(now)
+				}
+			}
+		}
+	}
+
+	b.exitingQuiescence = false
+}
+
+// Determines whether BBR needs to enter, exit or advance state of the
+// recovery.
+func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) {
+	// Disable recovery in startup, if loss-based exit is enabled.
+	if !b.isAtFullBandwidth {
+		return
+	}
+
+	// Exit recovery when there are no losses for a round.
+	if hasLosses {
+		b.endRecoveryAt = b.lastSentPacket
+	}
+
+	switch b.recoveryState {
+	case bbrRecoveryStateNotInRecovery:
+		if !hasLosses {
+			return
+		}
+		b.recoveryState = bbrRecoveryStateConservation
+		// This will cause the |recovery_window_| to be set to the correct
+		// value in CalculateRecoveryWindow().
+		b.recoveryWindow = 0
+		// Since the conservation phase is meant to be lasting for a whole
+		// round, extend the current round as if it were started right now.
+		b.currentRoundTripEnd = b.lastSentPacket
+
+	case bbrRecoveryStateConservation, bbrRecoveryStateGrowth:
+		// CONSERVATION advances to GROWTH at the start of a round; the exit
+		// check below then applies to both states.
+		if b.recoveryState == bbrRecoveryStateConservation && isRoundStart {
+			b.recoveryState = bbrRecoveryStateGrowth
+		}
+		// Exit recovery if appropriate.
+		if !hasLosses && lastAckedPacket > b.endRecoveryAt {
+			b.recoveryState = bbrRecoveryStateNotInRecovery
+		}
+	}
+}
+
+// Determines the appropriate pacing rate for the connection.
+func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) {
+	if b.bandwidthEstimate() == 0 {
+		return
+	}
+
+	targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate()))
+	if b.isAtFullBandwidth {
+		b.pacingRate = targetRate
+		return
+	}
+
+	// Pace at the rate of initial_window / RTT as soon as RTT measurements are
+	// available.
+	if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 {
+		b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT())
+		return
+	}
+
+	if b.detectOvershooting {
+		b.bytesLostWhileDetectingOvershooting += bytesLost
+		// Check for overshooting with network parameters adjusted when pacing rate
+		// > target_rate and loss has been detected.
+		if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 {
+			if b.hasNoAppLimitedSample ||
+				b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow {
+				// We are fairly sure overshoot happens if 1) there is at least one
+				// non app-limited bw sample or 2) half of IW gets lost. Slow pacing
+				// rate.
+				b.pacingRate = max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT()))
+				b.bytesLostWhileDetectingOvershooting = 0
+				b.detectOvershooting = false
+			}
+		}
+	}
+
+	// Do not decrease the pacing rate during startup.
+	b.pacingRate = max(b.pacingRate, targetRate)
+}
+
+// Determines the appropriate congestion window for the connection.
+// Nothing is changed while in PROBE_RTT.
+func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) {
+	if b.mode == bbrModeProbeRtt {
+		return
+	}
+
+	targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain)
+	if b.isAtFullBandwidth {
+		// Add the max recently measured ack aggregation to CWND.
+		targetWindow += b.sampler.MaxAckHeight()
+	} else if b.enableAckAggregationDuringStartup {
+		// Add the most recent excess acked. Because CWND never decreases in
+		// STARTUP, this will automatically create a very localized max filter.
+		targetWindow += excessAcked
+	}
+
+	// Instead of immediately setting the target CWND as the new one, BBR grows
+	// the CWND towards |target_window| by only increasing it |bytes_acked| at a
+	// time.
+	if b.isAtFullBandwidth {
+		b.congestionWindow = min(targetWindow, b.congestionWindow+bytesAcked)
+	} else if b.congestionWindow < targetWindow ||
+		b.sampler.TotalBytesAcked() < b.initialCongestionWindow {
+		// If the connection is not yet out of startup phase, do not decrease the
+		// window.
+		b.congestionWindow += bytesAcked
+	}
+
+	// Enforce the limits on the congestion window.
+	b.congestionWindow = max(b.congestionWindow, b.minCongestionWindow)
+	b.congestionWindow = min(b.congestionWindow, b.maxCongestionWindow)
+}
+
+// Determines the appropriate window that constrains the in-flight during recovery.
+// Nothing is changed when the sender is not in recovery.
+func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) {
+	if b.recoveryState == bbrRecoveryStateNotInRecovery {
+		return
+	}
+
+	// Set up the initial recovery window.
+	if b.recoveryWindow == 0 {
+		b.recoveryWindow = max(b.minCongestionWindow, b.bytesInFlight+bytesAcked)
+		return
+	}
+
+	// Remove losses from the recovery window, while accounting for a potential
+	// integer underflow.
+	if b.recoveryWindow >= bytesLost {
+		b.recoveryWindow -= bytesLost
+	} else {
+		b.recoveryWindow = b.maxDatagramSize
+	}
+
+	// In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH,
+	// release additional |bytes_acked| to achieve a slow-start-like behavior.
+	if b.recoveryState == bbrRecoveryStateGrowth {
+		b.recoveryWindow += bytesAcked
+	}
+
+	// Always allow sending at least |bytes_acked| in response.
+	b.recoveryWindow = max(b.recoveryWindow, b.bytesInFlight+bytesAcked)
+	b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow)
+}
+
+// Return whether we should exit STARTUP due to excessive loss: the round must
+// have seen at least defaultStartupFullLossCount loss events, the send state
+// must be valid, and the bytes lost in the round must exceed
+// quicBbr2DefaultLossThreshold of the bytes in flight at send time.
+func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool {
+	if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid {
+		return false
+	}
+
+	inflightAtSend := lastPacketSendState.bytesInFlight
+	if inflightAtSend > 0 && b.bytesLostInRound > 0 {
+		return b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold)
+	}
+	return false
+}
+
+// debugPrint writes a formatted, time-stamped debug line to standard output.
+func (b *bbrSender) debugPrint(format string, a ...any) {
+	fmt.Printf("[BBRSender] [%s] %s\n",
+		time.Now().Format("15:04:05"),
+		fmt.Sprintf(format, a...))
+}
+
+// bdpFromRttAndBandwidth returns the bandwidth-delay product in bytes.
+// NOTE(review): rtt*bandwidth is multiplied in ByteCount arithmetic before the
+// divisions; presumably this can overflow for very large rtt x bandwidth
+// products - confirm the expected ranges.
+func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount {
+	return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second)
+}
+
+// GetInitialPacketSize returns the initial packet size to assume for addr:
+// the IPv4 or IPv6 initial size for UDP addresses, and the minimum Initial
+// packet size for any other address type.
+func GetInitialPacketSize(addr net.Addr) congestion.ByteCount {
+	udpAddr, ok := addr.(*net.UDPAddr)
+	if !ok {
+		// If this is not a UDP address, we don't know anything about the MTU.
+		// Use the minimum size of an Initial packet as the max packet size.
+		return congestion.MinInitialPacketSize
+	}
+	if udpAddr.IP.To4() != nil {
+		return congestion.InitialPacketSizeIPv4
+	}
+	return congestion.InitialPacketSizeIPv6
+}
+
+// formatSpeed renders bw with two decimals, dividing by 1024 per unit step
+// from "bps" up to "Gbps".
+func formatSpeed(bw Bandwidth) string {
+	bwf := float64(bw)
+	units := []string{"bps", "Kbps", "Mbps", "Gbps"}
+	unitIndex := 0
+	for bwf > 1024 && unitIndex < len(units)-1 {
+		bwf /= 1024
+		unitIndex++
+	}
+	return fmt.Sprintf("%.2f %s", bwf, units[unitIndex])
+}
diff --git a/protocol/hysteria2/internal/congestion/bbr/clock.go b/protocol/hysteria2/internal/congestion/bbr/clock.go
new file mode 100644
index 0000000..a66344f
--- /dev/null
+++ b/protocol/hysteria2/internal/congestion/bbr/clock.go
@@ -0,0 +1,18 @@
+package bbr
+
+import "time"
+
+// A Clock returns the current time
+type Clock interface {
+	Now() time.Time
+}
+
+// DefaultClock implements the Clock interface using the Go stdlib clock.
+type DefaultClock struct{}
+
+var _ Clock = DefaultClock{}
+
+// Now gets the current time
+func (DefaultClock) Now() time.Time {
+	return time.Now()
+}
diff --git a/protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go b/protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go
new file mode 100644
index 0000000..e9fad5a
--- /dev/null
+++ b/protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go
@@ -0,0 +1,199 @@
+package bbr
+
+import (
+	"github.com/daeuniverse/quic-go/congestion"
+)
+
+// packetNumberIndexedQueue is a queue of mostly continuous numbered entries
+// which supports the following operations:
+// - adding elements to the end of the queue, or at some point past the end
+// - removing elements in any order
+// - retrieving elements
+// If all elements are inserted in order, all of the operations above are
+// amortized O(1) time.
+//
+// Internally, the data structure is a deque where each element is marked as
+// present or not. The deque starts at the lowest present index. Whenever an
+// element is removed, it's marked as not present, and the front of the deque is
+// cleared of elements that are not present.
+//
+// The tail of the queue is not cleared due to the assumption of entries being
+// inserted in order, though removing all elements of the queue will return it
+// to its initial state.
+//
+// Note that this data structure is inherently hazardous, since an addition of
+// just two entries will cause it to consume all of the memory available.
+// Because of that, it is not a general-purpose container and should not be used
+// as one.
+
+type entryWrapper[T any] struct {
+	present bool
+	entry   T
+}
+
+type packetNumberIndexedQueue[T any] struct {
+	entries                RingBuffer[entryWrapper[T]]
+	numberOfPresentEntries int
+	firstPacket            congestion.PacketNumber
+}
+
+func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] {
+	q := &packetNumberIndexedQueue[T]{
+		firstPacket: invalidPacketNumber,
+	}
+
+	q.entries.Init(size)
+
+	return q
+}
+
+// Emplace inserts data associated |packet_number| into (or past) the end of the
+// queue, filling up the missing intermediate entries as necessary. Returns
+// true if the element has been inserted successfully, false if it was already
+// in the queue or inserted out of order.
+func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool {
+	if packetNumber == invalidPacketNumber || entry == nil {
+		return false
+	}
+
+	if p.IsEmpty() {
+		p.entries.PushBack(entryWrapper[T]{
+			present: true,
+			entry:   *entry,
+		})
+		p.numberOfPresentEntries = 1
+		p.firstPacket = packetNumber
+		return true
+	}
+
+	// Do not allow insertion out-of-order.
+	if packetNumber <= p.LastPacket() {
+		return false
+	}
+
+	// Handle potentially missing elements.
+	offset := int(packetNumber - p.FirstPacket())
+	if gap := offset - p.entries.Len(); gap > 0 {
+		for i := 0; i < gap; i++ {
+			p.entries.PushBack(entryWrapper[T]{})
+		}
+	}
+
+	p.entries.PushBack(entryWrapper[T]{
+		present: true,
+		entry:   *entry,
+	})
+	p.numberOfPresentEntries++
+	return true
+}
+
+// GetEntry retrieves the entry associated with the packet number. Returns the pointer
+// to the entry in case of success, or nil if the entry does not exist.
+func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T {
+	ew := p.getEntryWraper(packetNumber)
+	if ew == nil {
+		return nil
+	}
+
+	return &ew.entry
+}
+
+// Remove removes the entry associated with the packet number. If the entry is present,
+// f (when non-nil) is called with it before removal; returns false if it is absent.
+func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool {
+	ew := p.getEntryWraper(packetNumber)
+	if ew == nil {
+		return false
+	}
+	if f != nil {
+		f(ew.entry)
+	}
+	ew.present = false
+	p.numberOfPresentEntries--
+
+	if packetNumber == p.FirstPacket() {
+		p.clearup()
+	}
+
+	return true
+}
+
+// RemoveUpTo removes all entries up to, but not including, |packet_number|.
+// Unused slots in the front are also removed, which means when the function
+// returns, |first_packet()| can be larger than |packet_number|.
+func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) {
+	for !p.entries.Empty() &&
+		p.firstPacket != invalidPacketNumber &&
+		p.firstPacket < packetNumber {
+		if p.entries.Front().present {
+			p.numberOfPresentEntries--
+		}
+		p.entries.PopFront()
+		p.firstPacket++
+	}
+	p.clearup()
+
+	return
+}
+
+// IsEmpty reports whether the queue holds no present entries.
+func (p *packetNumberIndexedQueue[T]) IsEmpty() bool {
+	return p.numberOfPresentEntries == 0
+}
+
+// NumberOfPresentEntries returns the number of entries in the queue.
+func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int { + return p.numberOfPresentEntries +} + +// EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is +// proportional to the memory usage of the queue. +func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int { + return p.entries.Len() +} + +// LastPacket returns packet number of the first entry in the queue. +func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) { + return p.firstPacket +} + +// LastPacket returns packet number of the last entry ever inserted in the queue. Note that the +// entry in question may have already been removed. Zero if the queue is +// empty. +func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) { + if p.IsEmpty() { + return invalidPacketNumber + } + + return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1) +} + +func (p *packetNumberIndexedQueue[T]) clearup() { + for !p.entries.Empty() && !p.entries.Front().present { + p.entries.PopFront() + p.firstPacket++ + } + if p.entries.Empty() { + p.firstPacket = invalidPacketNumber + } +} + +func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] { + if packetNumber == invalidPacketNumber || + p.IsEmpty() || + packetNumber < p.firstPacket { + return nil + } + + offset := int(packetNumber - p.firstPacket) + if offset >= p.entries.Len() { + return nil + } + + ew := p.entries.Offset(offset) + if ew == nil || !ew.present { + return nil + } + + return ew +} diff --git a/protocol/hysteria2/internal/congestion/bbr/ringbuffer.go b/protocol/hysteria2/internal/congestion/bbr/ringbuffer.go new file mode 100644 index 0000000..ed92d4c --- /dev/null +++ b/protocol/hysteria2/internal/congestion/bbr/ringbuffer.go @@ -0,0 +1,118 @@ +package bbr + +// A RingBuffer is a ring buffer. +// It acts as a heap that doesn't cause any allocations. 
// RingBuffer is a FIFO queue stored in a circular slice.
// Elements are appended at the tail with PushBack and consumed from the head
// with PopFront. When the backing slice is full its size is doubled, so a push
// only allocates when the buffer has to grow.
//
// The zero value is an empty, usable buffer. The type performs no
// synchronization of its own.
type RingBuffer[T any] struct {
	ring             []T
	headPos, tailPos int
	// full disambiguates headPos == tailPos, which is true both for an
	// empty ring and for a completely filled one.
	full bool
}

// Init allocates a backing slice with room for size elements and leaves the
// buffer empty. Any previously stored elements are discarded.
func (r *RingBuffer[T]) Init(size int) {
	r.ring = make([]T, size)
	// Reset the cursors as well, so Init on a used buffer cannot leave
	// positions that point outside (or into the middle of) the new slice.
	r.headPos, r.tailPos, r.full = 0, 0, false
}

// Len returns the number of elements in the ring buffer.
func (r *RingBuffer[T]) Len() int {
	switch {
	case r.full:
		return len(r.ring)
	case r.tailPos >= r.headPos:
		return r.tailPos - r.headPos
	default:
		// The live region wraps around the end of the slice.
		return r.tailPos - r.headPos + len(r.ring)
	}
}

// Empty says if the ring buffer is empty.
func (r *RingBuffer[T]) Empty() bool {
	return !r.full && r.headPos == r.tailPos
}

// PushBack adds a new element at the tail.
// If the ring buffer is full, its capacity is increased first.
func (r *RingBuffer[T]) PushBack(t T) {
	if r.full || len(r.ring) == 0 {
		r.grow()
	}
	r.ring[r.tailPos] = t
	r.tailPos++
	if r.tailPos == len(r.ring) {
		r.tailPos = 0
	}
	if r.tailPos == r.headPos {
		r.full = true
	}
}

// PopFront removes and returns the element at the head.
// It panics when the buffer is empty, so callers might need to check if there
// are elements in the buffer first.
func (r *RingBuffer[T]) PopFront() T {
	if r.Empty() {
		panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue")
	}
	r.full = false
	t := r.ring[r.headPos]
	// Zero the vacated slot so the ring does not keep the element reachable.
	var zero T
	r.ring[r.headPos] = zero
	r.headPos++
	if r.headPos == len(r.ring) {
		r.headPos = 0
	}
	return t
}

// Offset returns a pointer to the element index positions behind the head
// (index 0 is the front element).
// It panics when index is negative or not smaller than Len(), which includes
// every call on an empty buffer.
func (r *RingBuffer[T]) Offset(index int) *T {
	// A negative index must be rejected explicitly: with headPos > 0 the
	// modulo below would otherwise land on an already-consumed slot and
	// hand back stale data instead of failing.
	if index < 0 || index >= r.Len() {
		panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index")
	}
	return &r.ring[(r.headPos+index)%len(r.ring)]
}

// Front returns a pointer to the front element.
// It panics when the buffer is empty, so callers might need to check if there
// are elements in the buffer first.
func (r *RingBuffer[T]) Front() *T {
	if r.Empty() {
		panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue")
	}
	return &r.ring[r.headPos]
}

// Back returns a pointer to the back (most recently pushed) element.
// It panics when the buffer is empty, so callers might need to check if there
// are elements in the buffer first.
func (r *RingBuffer[T]) Back() *T {
	if r.Empty() {
		panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue")
	}
	return r.Offset(r.Len() - 1)
}

// grow doubles the size of the backing slice (or makes it 1 when it was 0) and
// moves the elements so that the head is at index 0.
// This method assumes the queue is full.
func (r *RingBuffer[T]) grow() {
	oldRing := r.ring
	newSize := len(oldRing) * 2
	if newSize == 0 {
		newSize = 1
	}
	r.ring = make([]T, newSize)
	// Copy head..end first, then start..head, which unrolls the wrap-around.
	headLen := copy(r.ring, oldRing[r.headPos:])
	copy(r.ring[headLen:], oldRing[:r.headPos])
	r.headPos, r.tailPos, r.full = 0, len(oldRing), false
}

// Clear removes all elements. The backing slice is kept and zeroed.
func (r *RingBuffer[T]) Clear() {
	clear(r.ring)
	r.headPos, r.tailPos, r.full = 0, 0, false
}
// The algorithm works as follows. On a reset, all three estimates are set to
// the same sample. The second best estimate is then recorded in the second
// quarter of the window, and a third best estimate is recorded in the second
// half of the window, bounding the worst case error when the true min is
// monotonically increasing (or true max is monotonically decreasing) over the
// window.
//
// A new best sample replaces all three estimates, since the new best is lower
// (or higher) than everything else in the window and it is the most recent.
// The window thus effectively gets reset on every new min. The same property
// holds true for second best and third best estimates. Specifically, when a
// sample arrives that is better than the second best but not better than the
// best, it replaces the second and third best estimates but not the best
// estimate. Similarly, a sample that is better than the third best estimate
// but not the other estimates replaces only the third best estimate.
//
// Finally, when the best expires, it is replaced by the second best, which in
// turn is replaced by the third best. The newest sample replaces the third
// best.

// WindowedFilterValue is the constraint for sample values. Any type is
// accepted; ordering is supplied by the comparator passed to
// NewWindowedFilter.
type WindowedFilterValue interface {
	any
}

// WindowedFilterTime is the constraint for timestamps and window lengths:
// any integer or floating-point type, since Update subtracts and compares them.
type WindowedFilterTime interface {
	constraints.Integer | constraints.Float
}

// WindowedFilter tracks the best, second-best and third-best samples seen
// within a sliding window of length windowLength.
type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct {
	// Time length of window.
	windowLength T
	// estimates[0] is the best, [1] the second best, [2] the third best.
	estimates []entry[V, T]
	// comparator(a, b) returns >0 if a is better than b, <0 if worse, 0 if equal.
	comparator func(V, V) int
}

// entry is one recorded sample together with the time it was recorded at.
type entry[V WindowedFilterValue, T WindowedFilterTime] struct {
	sample V
	time   T
}

// MaxFilter is a comparator for a windowed maximum: it returns 1 if a is
// greater than b, -1 if a is less than b, and 0 if they are equal.
func MaxFilter[O constraints.Ordered](a, b O) int {
	if a > b {
		return 1
	} else if a < b {
		return -1
	}
	return 0
}

// MinFilter is a comparator for a windowed minimum: it returns 1 if a is
// less than b, -1 if a is greater than b, and 0 if they are equal.
func MinFilter[O constraints.Ordered](a, b O) int {
	if a < b {
		return 1
	} else if a > b {
		return -1
	}
	return 0
}

// NewWindowedFilter returns a filter with the given window length whose three
// estimates are zero-valued. comparator decides which of two samples is
// better (see MaxFilter and MinFilter).
func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] {
	return &WindowedFilter[V, T]{
		windowLength: windowLength,
		estimates:    make([]entry[V, T], 3, 3),
		comparator:   comparator,
	}
}

// SetWindowLength changes the window length. Does not update any current samples.
func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) {
	f.windowLength = windowLength
}

// GetBest returns the current best sample.
func (f *WindowedFilter[V, T]) GetBest() V {
	return f.estimates[0].sample
}

// GetSecondBest returns the current second-best sample.
func (f *WindowedFilter[V, T]) GetSecondBest() V {
	return f.estimates[1].sample
}

// GetThirdBest returns the current third-best sample.
func (f *WindowedFilter[V, T]) GetThirdBest() V {
	return f.estimates[2].sample
}

// Update updates best estimates with |sample|, and expires and updates best
// estimates as necessary. newTime is the time the sample was taken at, in the
// same unit as the window length.
func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) {
	// Reset all estimates if they have not yet been initialized, if new sample
	// is a new best, or if the newest recorded estimate is too old.
	// "Not yet initialized" is detected by the best sample comparing equal to
	// the zero value of V, so a genuine best sample equal to the zero value is
	// treated the same way.
	if f.comparator(f.estimates[0].sample, *new(V)) == 0 ||
		f.comparator(newSample, f.estimates[0].sample) >= 0 ||
		newTime-f.estimates[2].time > f.windowLength {
		f.Reset(newSample, newTime)
		return
	}

	// A sample at least as good as the second (or third) best replaces it
	// and everything ranked below it.
	if f.comparator(newSample, f.estimates[1].sample) >= 0 {
		f.estimates[1] = entry[V, T]{newSample, newTime}
		f.estimates[2] = f.estimates[1]
	} else if f.comparator(newSample, f.estimates[2].sample) >= 0 {
		f.estimates[2] = entry[V, T]{newSample, newTime}
	}

	// Expire and update estimates as necessary.
	if newTime-f.estimates[0].time > f.windowLength {
		// The best estimate hasn't been updated for an entire window, so promote
		// second and third best estimates.
		f.estimates[0] = f.estimates[1]
		f.estimates[1] = f.estimates[2]
		f.estimates[2] = entry[V, T]{newSample, newTime}
		// Need to iterate one more time. Check if the new best estimate is
		// outside the window as well, since it may also have been recorded a
		// long time ago. Don't need to iterate once more since we cover that
		// case at the beginning of the method.
		if newTime-f.estimates[0].time > f.windowLength {
			f.estimates[0] = f.estimates[1]
			f.estimates[1] = f.estimates[2]
		}
		return
	}
	if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 &&
		newTime-f.estimates[1].time > f.windowLength/4 {
		// A quarter of the window has passed without a better sample, so the
		// second-best estimate is taken from the second quarter of the window.
		f.estimates[1] = entry[V, T]{newSample, newTime}
		f.estimates[2] = f.estimates[1]
		return
	}

	if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 &&
		newTime-f.estimates[2].time > f.windowLength/2 {
		// We've passed a half of the window without a better estimate, so take
		// a third-best estimate from the second half of the window.
		f.estimates[2] = entry[V, T]{newSample, newTime}
	}
}

// Reset sets all three estimates to the new sample and time.
func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) {
	f.estimates[2] = entry[V, T]{newSample, newTime}
	f.estimates[1] = f.estimates[2]
	f.estimates[0] = f.estimates[1]
}

// Clear replaces the estimates with three zero-valued entries, returning the
// filter to the state NewWindowedFilter created it in (window length and
// comparator are kept).
func (f *WindowedFilter[V, T]) Clear() {
	f.estimates = make([]entry[V, T], 3, 3)
}
+func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) { + f.estimates[2] = entry[V, T]{newSample, newTime} + f.estimates[1] = f.estimates[2] + f.estimates[0] = f.estimates[1] +} + +func (f *WindowedFilter[V, T]) Clear() { + f.estimates = make([]entry[V, T], 3, 3) +} diff --git a/protocol/hysteria2/internal/congestion/brutal/brutal.go b/protocol/hysteria2/internal/congestion/brutal/brutal.go new file mode 100644 index 0000000..b353090 --- /dev/null +++ b/protocol/hysteria2/internal/congestion/brutal/brutal.go @@ -0,0 +1,185 @@ +package brutal + +import ( + "fmt" + "os" + "strconv" + "time" + + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/common" + + "github.com/daeuniverse/quic-go/congestion" +) + +const ( + pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample + minSampleCount = 50 + minAckRate = 0.8 + congestionWindowMultiplier = 2 + + debugEnv = "HYSTERIA_BRUTAL_DEBUG" + debugPrintInterval = 2 +) + +var _ congestion.CongestionControl = &BrutalSender{} + +type BrutalSender struct { + rttStats congestion.RTTStatsProvider + bps congestion.ByteCount + maxDatagramSize congestion.ByteCount + pacer *common.Pacer + + pktInfoSlots [pktInfoSlotCount]pktInfo + ackRate float64 + + debug bool + lastAckPrintTimestamp int64 +} + +type pktInfo struct { + Timestamp int64 + AckCount uint64 + LossCount uint64 +} + +func NewBrutalSender(bps uint64) *BrutalSender { + debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) + bs := &BrutalSender{ + bps: congestion.ByteCount(bps), + maxDatagramSize: congestion.InitialPacketSizeIPv4, + ackRate: 1, + debug: debug, + } + bs.pacer = common.NewPacer(func() congestion.ByteCount { + return congestion.ByteCount(float64(bs.bps) / bs.ackRate) + }) + return bs +} + +func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { + b.rttStats = rttStats +} + +func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return 
b.pacer.TimeUntilSend() +} + +func (b *BrutalSender) HasPacingBudget(now time.Time) bool { + return b.pacer.Budget(now) >= b.maxDatagramSize +} + +func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight <= b.GetCongestionWindow() +} + +func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { + rtt := b.rttStats.SmoothedRTT() + if rtt <= 0 { + return 10240 + } + cwnd := congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) + if cwnd < b.maxDatagramSize { + cwnd = b.maxDatagramSize + } + return cwnd +} + +func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) +} + +func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, eventTime time.Time, +) { + // Stub +} + +func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, +) { + // Stub +} + +func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { + currentTimestamp := eventTime.Unix() + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets)) + b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets)) + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets)) + b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets)) + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { + b.maxDatagramSize = size 
+ b.pacer.SetMaxDatagramSize(size) + if b.debug { + b.debugPrint("SetMaxDatagramSize: %d", size) + } +} + +func (b *BrutalSender) updateAckRate(currentTimestamp int64) { + minTimestamp := currentTimestamp - pktInfoSlotCount + var ackCount, lossCount uint64 + for _, info := range b.pktInfoSlots { + if info.Timestamp < minTimestamp { + continue + } + ackCount += info.AckCount + lossCount += info.LossCount + } + if ackCount+lossCount < minSampleCount { + b.ackRate = 1 + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)", + ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return + } + rate := float64(ackCount) / float64(ackCount+lossCount) + if rate < minAckRate { + b.ackRate = minAckRate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } + return + } + b.ackRate = rate + if b.canPrintAckRate(currentTimestamp) { + b.lastAckPrintTimestamp = currentTimestamp + b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", + rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) + } +} + +func (b *BrutalSender) InSlowStart() bool { + return false +} + +func (b *BrutalSender) InRecovery() bool { + return false +} + +func (b *BrutalSender) MaybeExitSlowStart() {} + +func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} + +func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool { + return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval +} + +func (b *BrutalSender) debugPrint(format string, a ...any) { + fmt.Printf("[BrutalSender] [%s] %s\n", + time.Now().Format("15:04:05"), + fmt.Sprintf(format, a...)) 
+} diff --git a/protocol/hysteria2/internal/congestion/common/pacer.go b/protocol/hysteria2/internal/congestion/common/pacer.go new file mode 100644 index 0000000..9d55876 --- /dev/null +++ b/protocol/hysteria2/internal/congestion/common/pacer.go @@ -0,0 +1,79 @@ +package common + +import ( + "time" + + "github.com/daeuniverse/quic-go/congestion" +) + +const ( + maxBurstPackets = 10 + maxBurstPacingDelayMultiplier = 4 +) + +// Pacer implements a token bucket pacing algorithm. +type Pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime time.Time + getBandwidth func() congestion.ByteCount // in bytes/s +} + +func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer { + p := &Pacer{ + budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4, + maxDatagramSize: congestion.InitialPacketSizeIPv4, + getBandwidth: getBandwidth, + } + return p +} + +func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *Pacer) Budget(now time.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + if budget < 0 { // protect against overflows + budget = congestion.ByteCount(1<<62 - 1) + } + return min(p.maxBurstSize(), budget) +} + +func (p *Pacer) maxBurstSize() congestion.ByteCount { + return max( + congestion.ByteCount((maxBurstPacingDelayMultiplier*congestion.MinPacingDelay).Nanoseconds())*p.getBandwidth()/1e9, + maxBurstPackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. 
+func (p *Pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) + bw := uint64(p.getBandwidth()) + // We might need to round up this value. + // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. + d := diff / bw + // this is effectively a math.Ceil, but using only integer math + if diff%bw > 0 { + d++ + } + return p.lastSentTime.Add(max(congestion.MinPacingDelay, time.Duration(d)*time.Nanosecond)) +} + +func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} diff --git a/protocol/hysteria2/internal/congestion/utils.go b/protocol/hysteria2/internal/congestion/utils.go new file mode 100644 index 0000000..99a562a --- /dev/null +++ b/protocol/hysteria2/internal/congestion/utils.go @@ -0,0 +1,18 @@ +package congestion + +import ( + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/bbr" + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/brutal" + "github.com/daeuniverse/quic-go" +) + +func UseBBR(conn quic.Connection) { + conn.SetCongestionControl(bbr.NewBbrSender( + bbr.DefaultClock{}, + bbr.GetInitialPacketSize(conn.RemoteAddr()), + )) +} + +func UseBrutal(conn quic.Connection, tx uint64) { + conn.SetCongestionControl(brutal.NewBrutalSender(tx)) +} diff --git a/protocol/hysteria2/internal/frag/frag.go b/protocol/hysteria2/internal/frag/frag.go new file mode 100644 index 0000000..730dc44 --- /dev/null +++ b/protocol/hysteria2/internal/frag/frag.go @@ -0,0 +1,77 @@ +package frag + +import ( + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" +) + +func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage { + if m.Size() <= maxSize { + return []protocol.UDPMessage{*m} + } + fullPayload := m.Data + maxPayloadSize := maxSize - m.HeaderSize() + off := 0 + fragID := uint8(0) + fragCount := 
uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up + frags := make([]protocol.UDPMessage, fragCount) + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > maxPayloadSize { + payloadSize = maxPayloadSize + } + frag := *m + frag.FragID = fragID + frag.FragCount = fragCount + frag.Data = fullPayload[off : off+payloadSize] + frags[fragID] = frag + off += payloadSize + fragID++ + } + return frags +} + +// Defragger handles the defragmentation of UDP messages. +// The current implementation can only handle one packet ID at a time. +// If another packet arrives before a packet has received all fragments +// in their entirety, any previous state is discarded. +type Defragger struct { + pktID uint16 + frags []*protocol.UDPMessage + count uint8 + size int // data size +} + +func (d *Defragger) Feed(m *protocol.UDPMessage) *protocol.UDPMessage { + if m.FragCount <= 1 { + return m + } + if m.FragID >= m.FragCount { + // wtf is this? + return nil + } + if m.PacketID != d.pktID || m.FragCount != uint8(len(d.frags)) { + // new message, clear previous state + d.pktID = m.PacketID + d.frags = make([]*protocol.UDPMessage, m.FragCount) + d.frags[m.FragID] = m + d.count = 1 + d.size = len(m.Data) + } else if d.frags[m.FragID] == nil { + d.frags[m.FragID] = m + d.count++ + d.size += len(m.Data) + if int(d.count) == len(d.frags) { + // all fragments received, assemble + data := make([]byte, d.size) + off := 0 + for _, frag := range d.frags { + off += copy(data[off:], frag.Data) + } + m.Data = data + m.FragID = 0 + m.FragCount = 1 + return m + } + } + return nil +} diff --git a/protocol/hysteria2/internal/frag/frag_test.go b/protocol/hysteria2/internal/frag/frag_test.go new file mode 100644 index 0000000..71ba6e9 --- /dev/null +++ b/protocol/hysteria2/internal/frag/frag_test.go @@ -0,0 +1,336 @@ +package frag + +import ( + "reflect" + "testing" + + "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" +) + 
+func TestFragUDPMessage(t *testing.T) { + type args struct { + m *protocol.UDPMessage + maxSize int + } + tests := []struct { + name string + args args + want []protocol.UDPMessage + }{ + { + "no frag", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 123, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("hello"), + }, + 100, + }, + []protocol.UDPMessage{ + { + SessionID: 123, + PacketID: 123, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("hello"), + }, + }, + }, + { + "2 frags", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 123, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("hello"), + }, + 20, + }, + []protocol.UDPMessage{ + { + SessionID: 123, + PacketID: 123, + FragID: 0, + FragCount: 2, + Addr: "test:123", + Data: []byte("hel"), + }, + { + SessionID: 123, + PacketID: 123, + FragID: 1, + FragCount: 2, + Addr: "test:123", + Data: []byte("lo"), + }, + }, + }, + { + "4 frags", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 123, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("abcdefgh"), + }, + 19, + }, + []protocol.UDPMessage{ + { + SessionID: 123, + PacketID: 123, + FragID: 0, + FragCount: 4, + Addr: "test:123", + Data: []byte("ab"), + }, + { + SessionID: 123, + PacketID: 123, + FragID: 1, + FragCount: 4, + Addr: "test:123", + Data: []byte("cd"), + }, + { + SessionID: 123, + PacketID: 123, + FragID: 2, + FragCount: 4, + Addr: "test:123", + Data: []byte("ef"), + }, + { + SessionID: 123, + PacketID: 123, + FragID: 3, + FragCount: 4, + Addr: "test:123", + Data: []byte("gh"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) { + t.Errorf("FragUDPMessage() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDefragger(t *testing.T) { + type args struct { + m *protocol.UDPMessage + } + tests := []struct { + name string + args args 
+ want *protocol.UDPMessage + }{ + { + "no frag", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("hello"), + }, + }, + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("hello"), + }, + }, + { + "frag 0 - 1/2", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 0, + FragCount: 2, + Addr: "test:123", + Data: []byte("hello "), + }, + }, + nil, + }, + { + "frag 0 - 2/2", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 1, + FragCount: 2, + Addr: "test:123", + Data: []byte("moto"), + }, + }, + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("hello moto"), + }, + }, + { + "frag 1 - 1/3", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 0, + FragCount: 3, + Addr: "test:123", + Data: []byte("deco"), + }, + }, + nil, + }, + { + "frag 1 - 2/3", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 1, + FragCount: 3, + Addr: "test:123", + Data: []byte("*"), + }, + }, + nil, + }, + { + "frag 1 - 3/3", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 2, + FragCount: 3, + Addr: "test:123", + Data: []byte("27"), + }, + }, + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 987, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("deco*27"), + }, + }, + { + "frag 2 - 1/2", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 233, + FragID: 1, + FragCount: 2, + Addr: "test:123", + Data: []byte("shinsekai"), + }, + }, + nil, + }, + { + "frag 3 - 2/2", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 244, + FragID: 1, + FragCount: 2, + Addr: "test:123", + Data: []byte("what???"), + }, + }, + nil, + }, + { + "frag 2 - 2/2", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 233, + 
FragID: 1, + FragCount: 2, + Addr: "test:123", + Data: []byte(" annaijo"), + }, + }, + nil, + }, + { + "invalid id", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 233, + FragID: 88, + FragCount: 2, + Addr: "test:123", + Data: []byte("shinsekai"), + }, + }, + nil, + }, + { + "frag 2 - 1/2 re", + args{ + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 233, + FragID: 0, + FragCount: 2, + Addr: "test:123", + Data: []byte("shinsekai"), + }, + }, + &protocol.UDPMessage{ + SessionID: 123, + PacketID: 233, + FragID: 0, + FragCount: 1, + Addr: "test:123", + Data: []byte("shinsekai annaijo"), + }, + }, + } + + d := &Defragger{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := d.Feed(tt.args.m); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Feed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/protocol/hysteria2/internal/pmtud/avail.go b/protocol/hysteria2/internal/pmtud/avail.go new file mode 100644 index 0000000..cd7afd0 --- /dev/null +++ b/protocol/hysteria2/internal/pmtud/avail.go @@ -0,0 +1,7 @@ +//go:build linux || windows || darwin + +package pmtud + +const ( + DisablePathMTUDiscovery = false +) diff --git a/protocol/hysteria2/internal/pmtud/unavail.go b/protocol/hysteria2/internal/pmtud/unavail.go new file mode 100644 index 0000000..917b973 --- /dev/null +++ b/protocol/hysteria2/internal/pmtud/unavail.go @@ -0,0 +1,13 @@ +//go:build !linux && !windows && !darwin + +package pmtud + +// quic-go's MTU detection is enabled by default on all platforms. +// However, it only actually sets the DF bit on 3 supported platforms (Windows, macOS, Linux). +// As a result, on other platforms, probe packets that should never be fragmented will still +// be fragmented and transmitted. So we're only enabling it for platforms where we've verified +// its functionality for now. 
+ +const ( + DisablePathMTUDiscovery = true +) diff --git a/protocol/hysteria2/internal/protocol/http.go b/protocol/hysteria2/internal/protocol/http.go new file mode 100644 index 0000000..abcc1a4 --- /dev/null +++ b/protocol/hysteria2/internal/protocol/http.go @@ -0,0 +1,68 @@ +package protocol + +import ( + "net/http" + "strconv" +) + +const ( + URLHost = "hysteria" + URLPath = "/auth" + + RequestHeaderAuth = "Hysteria-Auth" + ResponseHeaderUDPEnabled = "Hysteria-UDP" + CommonHeaderCCRX = "Hysteria-CC-RX" + CommonHeaderPadding = "Hysteria-Padding" + + StatusAuthOK = 233 +) + +// AuthRequest is what client sends to server for authentication. +type AuthRequest struct { + Auth string + Rx uint64 // 0 = unknown, client asks server to use bandwidth detection +} + +// AuthResponse is what server sends to client when authentication is passed. +type AuthResponse struct { + UDPEnabled bool + Rx uint64 // 0 = unlimited + RxAuto bool // true = server asks client to use bandwidth detection +} + +func AuthRequestFromHeader(h http.Header) AuthRequest { + rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) + return AuthRequest{ + Auth: h.Get(RequestHeaderAuth), + Rx: rx, + } +} + +func AuthRequestToHeader(h http.Header, req AuthRequest) { + h.Set(RequestHeaderAuth, req.Auth) + h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10)) + h.Set(CommonHeaderPadding, authRequestPadding.String()) +} + +func AuthResponseFromHeader(h http.Header) AuthResponse { + resp := AuthResponse{} + resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) + rxStr := h.Get(CommonHeaderCCRX) + if rxStr == "auto" { + // Special case for server requesting client to use bandwidth detection + resp.RxAuto = true + } else { + resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64) + } + return resp +} + +func AuthResponseToHeader(h http.Header, resp AuthResponse) { + h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled)) + if resp.RxAuto { + h.Set(CommonHeaderCCRX, "auto") + } else 
{ + h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10)) + } + h.Set(CommonHeaderPadding, authResponsePadding.String()) +} diff --git a/protocol/hysteria2/internal/protocol/padding.go b/protocol/hysteria2/internal/protocol/padding.go new file mode 100644 index 0000000..9895cdc --- /dev/null +++ b/protocol/hysteria2/internal/protocol/padding.go @@ -0,0 +1,31 @@ +package protocol + +import ( + "math/rand" +) + +const ( + paddingChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +// padding specifies a half-open range [Min, Max). +type padding struct { + Min int + Max int +} + +func (p padding) String() string { + n := p.Min + rand.Intn(p.Max-p.Min) + bs := make([]byte, n) + for i := range bs { + bs[i] = paddingChars[rand.Intn(len(paddingChars))] + } + return string(bs) +} + +var ( + authRequestPadding = padding{Min: 256, Max: 2048} + authResponsePadding = padding{Min: 256, Max: 2048} + tcpRequestPadding = padding{Min: 64, Max: 512} + tcpResponsePadding = padding{Min: 128, Max: 1024} +) diff --git a/protocol/hysteria2/internal/protocol/proxy.go b/protocol/hysteria2/internal/protocol/proxy.go new file mode 100644 index 0000000..2a98a8e --- /dev/null +++ b/protocol/hysteria2/internal/protocol/proxy.go @@ -0,0 +1,255 @@ +package protocol + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/daeuniverse/outbound/protocol/hysteria2/errors" + + "github.com/daeuniverse/quic-go/quicvarint" +) + +const ( + FrameTypeTCPRequest = 0x401 + + // Max length values are for preventing DoS attacks + + MaxAddressLength = 2048 + MaxMessageLength = 2048 + MaxPaddingLength = 4096 + + MaxUDPSize = 4096 + + maxVarInt1 = 63 + maxVarInt2 = 16383 + maxVarInt4 = 1073741823 + maxVarInt8 = 4611686018427387903 +) + +// TCPRequest format: +// 0x401 (QUIC varint) +// Address length (QUIC varint) +// Address (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPRequest(r io.Reader) (string, error) { + bReader := 
quicvarint.NewReader(r) + addrLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if addrLen == 0 || addrLen > MaxAddressLength { + return "", errors.ProtocolError{Message: "invalid address length"} + } + addrBuf := make([]byte, addrLen) + _, err = io.ReadFull(r, addrBuf) + if err != nil { + return "", err + } + paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return "", err + } + if paddingLen > MaxPaddingLength { + return "", errors.ProtocolError{Message: "invalid padding length"} + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return "", err + } + } + return string(addrBuf), nil +} + +func WriteTCPRequest(w io.Writer, addr string) error { + padding := tcpRequestPadding.String() + paddingLen := len(padding) + addrLen := len(addr) + sz := int(quicvarint.Len(FrameTypeTCPRequest)) + + int(quicvarint.Len(uint64(addrLen))) + addrLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buf := make([]byte, sz) + i := varintPut(buf, FrameTypeTCPRequest) + i += varintPut(buf[i:], uint64(addrLen)) + i += copy(buf[i:], addr) + i += varintPut(buf[i:], uint64(paddingLen)) + copy(buf[i:], padding) + _, err := w.Write(buf) + return err +} + +// TCPResponse format: +// Status (byte, 0=ok, 1=error) +// Message length (QUIC varint) +// Message (bytes) +// Padding length (QUIC varint) +// Padding (bytes) + +func ReadTCPResponse(r io.Reader) (bool, string, error) { + var status [1]byte + if _, err := io.ReadFull(r, status[:]); err != nil { + return false, "", err + } + bReader := quicvarint.NewReader(r) + msgLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if msgLen > MaxMessageLength { + return false, "", errors.ProtocolError{Message: "invalid message length"} + } + var msgBuf []byte + // No message is fine + if msgLen > 0 { + msgBuf = make([]byte, msgLen) + _, err = io.ReadFull(r, msgBuf) + if err != nil { + return false, "", err + } + } + 
paddingLen, err := quicvarint.Read(bReader) + if err != nil { + return false, "", err + } + if paddingLen > MaxPaddingLength { + return false, "", errors.ProtocolError{Message: "invalid padding length"} + } + if paddingLen > 0 { + _, err = io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return false, "", err + } + } + return status[0] == 0, string(msgBuf), nil +} + +func WriteTCPResponse(w io.Writer, ok bool, msg string) error { + padding := tcpResponsePadding.String() + paddingLen := len(padding) + msgLen := len(msg) + sz := 1 + int(quicvarint.Len(uint64(msgLen))) + msgLen + + int(quicvarint.Len(uint64(paddingLen))) + paddingLen + buf := make([]byte, sz) + if ok { + buf[0] = 0 + } else { + buf[0] = 1 + } + i := varintPut(buf[1:], uint64(msgLen)) + i += copy(buf[1+i:], msg) + i += varintPut(buf[1+i:], uint64(paddingLen)) + copy(buf[1+i:], padding) + _, err := w.Write(buf) + return err +} + +// UDPMessage format: +// Session ID (uint32 BE) +// Packet ID (uint16 BE) +// Fragment ID (uint8) +// Fragment count (uint8) +// Address length (QUIC varint) +// Address (bytes) +// Data... 
+ +type UDPMessage struct { + SessionID uint32 // 4 + PacketID uint16 // 2 + FragID uint8 // 1 + FragCount uint8 // 1 + Addr string // varint + bytes + Data []byte +} + +func (m *UDPMessage) HeaderSize() int { + lAddr := len(m.Addr) + return 4 + 2 + 1 + 1 + int(quicvarint.Len(uint64(lAddr))) + lAddr +} + +func (m *UDPMessage) Size() int { + return m.HeaderSize() + len(m.Data) +} + +func (m *UDPMessage) Serialize(buf []byte) int { + // Make sure the buffer is big enough + if len(buf) < m.Size() { + return -1 + } + binary.BigEndian.PutUint32(buf, m.SessionID) + binary.BigEndian.PutUint16(buf[4:], m.PacketID) + buf[6] = m.FragID + buf[7] = m.FragCount + i := varintPut(buf[8:], uint64(len(m.Addr))) + i += copy(buf[8+i:], m.Addr) + i += copy(buf[8+i:], m.Data) + return 8 + i +} + +func ParseUDPMessage(msg []byte) (*UDPMessage, error) { + m := &UDPMessage{} + buf := bytes.NewBuffer(msg) + if err := binary.Read(buf, binary.BigEndian, &m.SessionID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.PacketID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragID); err != nil { + return nil, err + } + if err := binary.Read(buf, binary.BigEndian, &m.FragCount); err != nil { + return nil, err + } + lAddr, err := quicvarint.Read(buf) + if err != nil { + return nil, err + } + if lAddr == 0 || lAddr > MaxMessageLength { + return nil, errors.ProtocolError{Message: "invalid address length"} + } + bs := buf.Bytes() + if len(bs) <= int(lAddr) { + // We use <= instead of < here as we expect at least one byte of data after the address + return nil, errors.ProtocolError{Message: "invalid message length"} + } + m.Addr = string(bs[:lAddr]) + m.Data = bs[lAddr:] + return m, nil +} + +// varintPut is like quicvarint.Append, but instead of appending to a slice, +// it writes to a fixed-size buffer. Returns the number of bytes written. 
+func varintPut(b []byte, i uint64) int { + if i <= maxVarInt1 { + b[0] = uint8(i) + return 1 + } + if i <= maxVarInt2 { + b[0] = uint8(i>>8) | 0x40 + b[1] = uint8(i) + return 2 + } + if i <= maxVarInt4 { + b[0] = uint8(i>>24) | 0x80 + b[1] = uint8(i >> 16) + b[2] = uint8(i >> 8) + b[3] = uint8(i) + return 4 + } + if i <= maxVarInt8 { + b[0] = uint8(i>>56) | 0xc0 + b[1] = uint8(i >> 48) + b[2] = uint8(i >> 40) + b[3] = uint8(i >> 32) + b[4] = uint8(i >> 24) + b[5] = uint8(i >> 16) + b[6] = uint8(i >> 8) + b[7] = uint8(i) + return 8 + } + panic(fmt.Sprintf("%#x doesn't fit into 62 bits", i)) +} diff --git a/protocol/hysteria2/internal/protocol/proxy_test.go b/protocol/hysteria2/internal/protocol/proxy_test.go new file mode 100644 index 0000000..111c615 --- /dev/null +++ b/protocol/hysteria2/internal/protocol/proxy_test.go @@ -0,0 +1,317 @@ +package protocol + +import ( + "bytes" + "reflect" + "strings" + "testing" +) + +func TestUDPMessage(t *testing.T) { + t.Run("buffer too small", func(t *testing.T) { + // Make sure Serialize returns -1 when the buffer is too small. 
+ tBuf := make([]byte, 20) + if (&UDPMessage{ + SessionID: 66, + PacketID: 77, + FragID: 2, + FragCount: 5, + Addr: "random_addr", + Data: []byte("random_data"), + }).Serialize(tBuf) != -1 { + t.Error("Serialize() did not return -1 when the buffer was too small") + } + }) + + type fields struct { + SessionID uint32 + PacketID uint16 + FragID uint8 + FragCount uint8 + Addr string + Data []byte + } + tests := []struct { + name string + fields fields + want []byte + }{ + { + name: "test 1", + fields: fields{ + SessionID: 1, + PacketID: 1, + FragID: 0, + FragCount: 1, + Addr: "example.com:80", + Data: []byte("GET /nothing HTTP/1.1\r\n"), + }, + want: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x1, 0xe, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x3a, 0x38, 0x30, 0x47, 0x45, 0x54, 0x20, 0x2f, 0x6e, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x20, 0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, 0x2e, 0x31, 0xd, 0xa}, + }, + { + name: "test 2", + fields: fields{ + SessionID: 1329655244, + Addr: "some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long_some_random_goofy_ahh_address_which_is_very_long:9000", + PacketID: 62233, + FragID: 8, + FragCount: 19, + Data: []byte("God is great, beer is good, and people are crazy."), + }, + want: []byte{0x4f, 0x40, 0xed, 0xcc, 0xf3, 0x19, 0x8, 0x13, 0x41, 0xee, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 
0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 
0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x5f, 0x73, 0x6f, 0x6d, 0x65, 0x5f, 0x72, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x5f, 0x67, 0x6f, 0x6f, 0x66, 0x79, 0x5f, 0x61, 0x68, 0x68, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x5f, 0x77, 0x68, 0x69, 0x63, 0x68, 0x5f, 0x69, 0x73, 0x5f, 0x76, 0x65, 0x72, 0x79, 0x5f, 0x6c, 0x6f, 0x6e, 0x67, 0x3a, 0x39, 0x30, 0x30, 0x30, 0x47, 0x6f, 0x64, 0x20, 0x69, 0x73, 0x20, 0x67, 0x72, 0x65, 0x61, 0x74, 0x2c, 0x20, 0x62, 0x65, 0x65, 0x72, 0x20, 0x69, 0x73, 0x20, 0x67, 0x6f, 0x6f, 0x64, 0x2c, 0x20, 0x61, 0x6e, 0x64, 0x20, 0x70, 0x65, 0x6f, 0x70, 0x6c, 0x65, 0x20, 0x61, 0x72, 0x65, 0x20, 0x63, 0x72, 0x61, 0x7a, 0x79, 0x2e}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &UDPMessage{ + SessionID: tt.fields.SessionID, + Addr: tt.fields.Addr, + PacketID: tt.fields.PacketID, + FragID: tt.fields.FragID, + FragCount: tt.fields.FragCount, + Data: tt.fields.Data, + } + // Serialize + buf := make([]byte, MaxUDPSize) + n := m.Serialize(buf) + if got := buf[:n]; !reflect.DeepEqual(got, tt.want) { + t.Errorf("Serialize() = %v, want %v", got, tt.want) + } + // Parse back + if m2, err := ParseUDPMessage(tt.want); err != nil { + t.Errorf("ParseUDPMessage() error = %v", err) + } else { + if !reflect.DeepEqual(m2, m) { + t.Errorf("ParseUDPMessage() = %v, want %v", m2, m) + } + } + }) + } +} + +// TestUDPMessageMalformed is to make sure ParseUDPMessage() fails (but not panic) on malformed data. 
+func TestUDPMessageMalformed(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + { + name: "empty", + data: []byte{}, + }, + { + name: "zeroes 1", + data: []byte{0, 0, 0, 0}, + }, + { + name: "zeroes 2", + data: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + { + name: "incomplete 1", + data: []byte{0x66, 0xCC, 0xFF, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x55}, + }, + { + name: "incomplete 2", + data: []byte{0x66, 0xCC, 0xFF, 0xFF, 0x11, 0x22, 0x33, 0x44, 0x90, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := ParseUDPMessage(tt.data); err == nil { + t.Errorf("ParseUDPMessage() should fail") + } + }) + } +} + +func TestReadTCPRequest(t *testing.T) { + tests := []struct { + name string + data []byte + want string + wantErr bool + }{ + { + name: "normal no padding", + data: []byte("\x0egoogle.com:443\x00"), + want: "google.com:443", + wantErr: false, + }, + { + name: "normal with padding", + data: []byte("\x0bholy.cc:443\x02gg"), + want: "holy.cc:443", + wantErr: false, + }, + { + name: "incomplete 1", + data: []byte("\x0bhoho"), + want: "", + wantErr: true, + }, + { + name: "incomplete 2", + data: []byte("\x0bholy.cc:443\x05x"), + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, err := ReadTCPRequest(r) + if (err != nil) != tt.wantErr { + t.Errorf("ReadTCPRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ReadTCPRequest() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteTCPRequest(t *testing.T) { + tests := []struct { + name string + addr string + wantW string // Just a prefix, we don't care about the padding + wantErr bool + }{ + { + name: "normal 1", + addr: "google.com:443", + wantW: "\x44\x01\x0egoogle.com:443", + wantErr: false, + }, + { + name: "normal 2", + addr: "client-api.arkoselabs.com:8080", + wantW: 
"\x44\x01\x1eclient-api.arkoselabs.com:8080", + wantErr: false, + }, + { + name: "empty", + addr: "", + wantW: "\x44\x01\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := WriteTCPRequest(w, tt.addr) + if (err != nil) != tt.wantErr { + t.Errorf("WriteTCPRequest() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) { + t.Errorf("WriteTCPRequest() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} + +func TestReadTCPResponse(t *testing.T) { + tests := []struct { + name string + data []byte + want bool + want1 string + wantErr bool + }{ + { + name: "normal ok no padding", + data: []byte("\x00\x0bhello world\x00"), + want: true, + want1: "hello world", + wantErr: false, + }, + { + name: "normal error with padding", + data: []byte("\x01\x06stop!!\x05xxxxx"), + want1: "stop!!", + wantErr: false, + }, + { + name: "normal ok no message with padding", + data: []byte("\x01\x00\x05xxxxx"), + want1: "", + wantErr: false, + }, + { + name: "incomplete 1", + data: []byte("\x00\x0bhoho"), + want1: "", + wantErr: true, + }, + { + name: "incomplete 2", + data: []byte("\x01\x05jesus\x05x"), + want1: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := bytes.NewReader(tt.data) + got, got1, err := ReadTCPResponse(r) + if (err != nil) != tt.wantErr { + t.Errorf("ReadTCPResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ReadTCPResponse() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("ReadTCPResponse() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestWriteTCPResponse(t *testing.T) { + type args struct { + ok bool + msg string + } + tests := []struct { + name string + args args + wantW string // Just a prefix, we don't care about the padding + wantErr bool + }{ + { + name: 
"normal ok", + args: args{ok: true, msg: "hello world"}, + wantW: "\x00\x0bhello world", + wantErr: false, + }, + { + name: "normal error", + args: args{ok: false, msg: "stop!!"}, + wantW: "\x01\x06stop!!", + wantErr: false, + }, + { + name: "empty", + args: args{ok: true, msg: ""}, + wantW: "\x00\x00", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + err := WriteTCPResponse(w, tt.args.ok, tt.args.msg) + if (err != nil) != tt.wantErr { + t.Errorf("WriteTCPResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) { + t.Errorf("WriteTCPResponse() gotW = %v, want %v", gotW, tt.wantW) + } + }) + } +} diff --git a/protocol/hysteria2/internal/utils/atomic.go b/protocol/hysteria2/internal/utils/atomic.go new file mode 100644 index 0000000..e3c3d97 --- /dev/null +++ b/protocol/hysteria2/internal/utils/atomic.go @@ -0,0 +1,24 @@ +package utils + +import ( + "sync/atomic" + "time" +) + +type AtomicTime struct { + v atomic.Value +} + +func NewAtomicTime(t time.Time) *AtomicTime { + a := &AtomicTime{} + a.Set(t) + return a +} + +func (t *AtomicTime) Set(new time.Time) { + t.v.Store(new) +} + +func (t *AtomicTime) Get() time.Time { + return t.v.Load().(time.Time) +} diff --git a/protocol/hysteria2/internal/utils/qstream.go b/protocol/hysteria2/internal/utils/qstream.go new file mode 100644 index 0000000..cfb23ed --- /dev/null +++ b/protocol/hysteria2/internal/utils/qstream.go @@ -0,0 +1,62 @@ +package utils + +import ( + "context" + "time" + + "github.com/daeuniverse/quic-go" +) + +// QStream is a wrapper of quic.Stream that handles Close() in a way that +// makes more sense to us. By default, quic.Stream's Close() only closes +// the write side of the stream, not the read side. 
And if there is unread +// data, the stream is not really considered closed until either the data +// is drained or CancelRead() is called. +// References: +// - https://github.com/libp2p/go-libp2p/blob/master/p2p/transport/quic/stream.go +// - https://github.com/quic-go/quic-go/issues/3558 +// - https://github.com/quic-go/quic-go/issues/1599 +type QStream struct { + Stream quic.Stream +} + +func (s *QStream) StreamID() quic.StreamID { + return s.Stream.StreamID() +} + +func (s *QStream) Read(p []byte) (n int, err error) { + return s.Stream.Read(p) +} + +func (s *QStream) CancelRead(code quic.StreamErrorCode) { + s.Stream.CancelRead(code) +} + +func (s *QStream) SetReadDeadline(t time.Time) error { + return s.Stream.SetReadDeadline(t) +} + +func (s *QStream) Write(p []byte) (n int, err error) { + return s.Stream.Write(p) +} + +func (s *QStream) Close() error { + s.Stream.CancelRead(0) + return s.Stream.Close() +} + +func (s *QStream) CancelWrite(code quic.StreamErrorCode) { + s.Stream.CancelWrite(code) +} + +func (s *QStream) Context() context.Context { + return s.Stream.Context() +} + +func (s *QStream) SetWriteDeadline(t time.Time) error { + return s.Stream.SetWriteDeadline(t) +} + +func (s *QStream) SetDeadline(t time.Time) error { + return s.Stream.SetDeadline(t) +} From a79104dac1866febf86c149ff632c4184afff1fd Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 00:50:13 +0800 Subject: [PATCH 03/13] feat: update udp connection interface --- protocol/hysteria2/client/client.go | 19 +++++------- protocol/hysteria2/client/reconnect.go | 7 +++-- protocol/hysteria2/client/udp.go | 41 ++++++++++++++++++-------- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/protocol/hysteria2/client/client.go b/protocol/hysteria2/client/client.go index 5ef5a98..e4caa6e 100644 --- a/protocol/hysteria2/client/client.go +++ b/protocol/hysteria2/client/client.go @@ -8,6 +8,7 @@ import ( "net/url" "time" + 
"github.com/daeuniverse/outbound/netproxy" coreErrs "github.com/daeuniverse/outbound/protocol/hysteria2/errors" "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion" "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" @@ -23,14 +24,8 @@ const ( ) type Client interface { - TCP(addr string) (net.Conn, error) - UDP() (HyUDPConn, error) - Close() error -} - -type HyUDPConn interface { - Receive() ([]byte, string, error) - Send([]byte, string) error + TCP(addr string) (netproxy.Conn, error) + UDP() (netproxy.Conn, error) Close() error } @@ -56,7 +51,7 @@ func NewClient(config *Config) (Client, *HandshakeInfo, error) { type clientImpl struct { config *Config - pktConn net.PacketConn + pktConn netproxy.PacketConn conn quic.Connection udpSM *udpSessionManager @@ -148,7 +143,7 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) { } _ = resp.Body.Close() - c.pktConn = pktConn + c.pktConn = pktConn.(netproxy.PacketConn) c.conn = conn if authResp.UDPEnabled { c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) @@ -168,7 +163,7 @@ func (c *clientImpl) openStream() (quic.Stream, error) { return &utils.QStream{Stream: stream}, nil } -func (c *clientImpl) TCP(addr string) (net.Conn, error) { +func (c *clientImpl) TCP(addr string) (netproxy.Conn, error) { stream, err := c.openStream() if err != nil { return nil, wrapIfConnectionClosed(err) @@ -208,7 +203,7 @@ func (c *clientImpl) TCP(addr string) (net.Conn, error) { }, nil } -func (c *clientImpl) UDP() (HyUDPConn, error) { +func (c *clientImpl) UDP() (netproxy.Conn, error) { if c.udpSM == nil { return nil, coreErrs.DialError{Message: "UDP not enabled"} } diff --git a/protocol/hysteria2/client/reconnect.go b/protocol/hysteria2/client/reconnect.go index 659397c..00586ee 100644 --- a/protocol/hysteria2/client/reconnect.go +++ b/protocol/hysteria2/client/reconnect.go @@ -4,6 +4,7 @@ import ( "net" "sync" + "github.com/daeuniverse/outbound/netproxy" coreErrs 
"github.com/daeuniverse/outbound/protocol/hysteria2/errors" ) @@ -89,7 +90,7 @@ func (rc *reconnectableClientImpl) clientDo(f func(Client) (interface{}, error)) return ret, err } -func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { +func (rc *reconnectableClientImpl) TCP(addr string) (netproxy.Conn, error) { if c, err := rc.clientDo(func(client Client) (interface{}, error) { return client.TCP(addr) }); err != nil { @@ -99,13 +100,13 @@ func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { } } -func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { +func (rc *reconnectableClientImpl) UDP() (netproxy.Conn, error) { if c, err := rc.clientDo(func(client Client) (interface{}, error) { return client.UDP() }); err != nil { return nil, err } else { - return c.(HyUDPConn), nil + return c.(netproxy.Conn), nil } } diff --git a/protocol/hysteria2/client/udp.go b/protocol/hysteria2/client/udp.go index 5378f56..95570a8 100644 --- a/protocol/hysteria2/client/udp.go +++ b/protocol/hysteria2/client/udp.go @@ -5,9 +5,11 @@ import ( "io" "math/rand" "sync" + "time" "github.com/daeuniverse/quic-go" + "github.com/daeuniverse/outbound/netproxy" coreErrs "github.com/daeuniverse/outbound/protocol/hysteria2/errors" "github.com/daeuniverse/outbound/protocol/hysteria2/internal/frag" "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" @@ -32,34 +34,49 @@ type udpConn struct { Closed bool } -func (u *udpConn) Receive() ([]byte, string, error) { +func (u *udpConn) Read(b []byte) (n int, err error) { for { msg := <-u.ReceiveCh if msg == nil { // Closed - return nil, "", io.EOF + return 0, io.EOF } dfMsg := u.D.Feed(msg) if dfMsg == nil { // Incomplete message, wait for more continue } - return dfMsg.Data, dfMsg.Addr, nil + n = copy(b, dfMsg.Data) + return n, nil } } -// Send is not thread-safe, as it uses a shared SendBuf. 
-func (u *udpConn) Send(data []byte, addr string) error { +func (u *udpConn) SetDeadline(t time.Time) error { + // TODO: Implement + return nil +} + +func (u *udpConn) SetReadDeadline(t time.Time) error { + // TODO: Implement + return nil +} + +func (u *udpConn) SetWriteDeadline(t time.Time) error { + // TODO: Implement + return nil +} + +func (u *udpConn) Write(b []byte) (n int, err error) { // Try no frag first msg := &protocol.UDPMessage{ SessionID: u.ID, PacketID: 0, FragID: 0, FragCount: 1, - Addr: addr, - Data: data, + Addr: "", + Data: b, } - err := u.SendFunc(u.SendBuf, msg) + err = u.SendFunc(u.SendBuf, msg) var errTooLarge *quic.DatagramTooLargeError if errors.As(err, &errTooLarge) { // Message too large, try fragmentation @@ -68,12 +85,12 @@ func (u *udpConn) Send(data []byte, addr string) error { for _, fMsg := range fMsgs { err := u.SendFunc(u.SendBuf, &fMsg) if err != nil { - return err + return 0, err } } - return nil + return len(b), nil } else { - return err + return len(b), err } } @@ -142,7 +159,7 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { } // NewUDP creates a new UDP session. 
-func (m *udpSessionManager) NewUDP() (HyUDPConn, error) { +func (m *udpSessionManager) NewUDP() (netproxy.Conn, error) { m.mutex.Lock() defer m.mutex.Unlock() From a9634a62d5efb38d5450782f34c91c3506ca8504 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:11:42 +0800 Subject: [PATCH 04/13] refactor: add timer functionality to udpConn --- protocol/hysteria2/client/udp.go | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/protocol/hysteria2/client/udp.go b/protocol/hysteria2/client/udp.go index 95570a8..2a20197 100644 --- a/protocol/hysteria2/client/udp.go +++ b/protocol/hysteria2/client/udp.go @@ -32,6 +32,9 @@ type udpConn struct { SendFunc func([]byte, *protocol.UDPMessage) error CloseFunc func() Closed bool + + muTimer sync.Mutex + timer *time.Timer } func (u *udpConn) Read(b []byte) (n int, err error) { @@ -52,18 +55,30 @@ func (u *udpConn) Read(b []byte) (n int, err error) { } func (u *udpConn) SetDeadline(t time.Time) error { - // TODO: Implement + u.muTimer.Lock() + defer u.muTimer.Unlock() + dur := time.Until(t) + if u.timer != nil { + u.timer.Reset(dur) + } else { + u.timer = time.AfterFunc(dur, func() { + u.muTimer.Lock() + defer u.muTimer.Unlock() + u.Close() + u.timer = nil + }) + } return nil } func (u *udpConn) SetReadDeadline(t time.Time) error { - // TODO: Implement - return nil + // FIXME: Single direction. + return u.SetDeadline(t) } func (u *udpConn) SetWriteDeadline(t time.Time) error { - // TODO: Implement - return nil + // FIXME: Single direction. 
+ return u.SetDeadline(t) } func (u *udpConn) Write(b []byte) (n int, err error) { @@ -176,6 +191,8 @@ func (m *udpSessionManager) NewUDP() (netproxy.Conn, error) { ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize), SendBuf: make([]byte, protocol.MaxUDPSize), SendFunc: m.io.SendMessage, + + muTimer: sync.Mutex{}, } conn.CloseFunc = func() { m.mutex.Lock() From 28e8097438f54ad59152dd90f55b0d3e8c5e40ed Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 02:18:18 +0800 Subject: [PATCH 05/13] chore: Update net.PacketConn in hysteria2 client --- protocol/hysteria2/client/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/protocol/hysteria2/client/client.go b/protocol/hysteria2/client/client.go index e4caa6e..e9c5d59 100644 --- a/protocol/hysteria2/client/client.go +++ b/protocol/hysteria2/client/client.go @@ -51,7 +51,7 @@ func NewClient(config *Config) (Client, *HandshakeInfo, error) { type clientImpl struct { config *Config - pktConn netproxy.PacketConn + pktConn net.PacketConn conn quic.Connection udpSM *udpSessionManager @@ -143,7 +143,7 @@ func (c *clientImpl) connect() (*HandshakeInfo, error) { } _ = resp.Body.Close() - c.pktConn = pktConn.(netproxy.PacketConn) + c.pktConn = pktConn c.conn = conn if authResp.UDPEnabled { c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) From a0211c954d128a463b95ebfcc0c91d357147c5fe Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 03:12:23 +0800 Subject: [PATCH 06/13] feat: Add UDP address parameter to hysteria2 client's UDP method --- protocol/hysteria2/client/client.go | 6 +- protocol/hysteria2/client/reconnect.go | 4 +- protocol/hysteria2/client/udp.go | 76 +++++++++++++++----------- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/protocol/hysteria2/client/client.go b/protocol/hysteria2/client/client.go index e9c5d59..fedae2d 100644 --- 
a/protocol/hysteria2/client/client.go +++ b/protocol/hysteria2/client/client.go @@ -25,7 +25,7 @@ const ( type Client interface { TCP(addr string) (netproxy.Conn, error) - UDP() (netproxy.Conn, error) + UDP(addr string) (netproxy.Conn, error) Close() error } @@ -203,11 +203,11 @@ func (c *clientImpl) TCP(addr string) (netproxy.Conn, error) { }, nil } -func (c *clientImpl) UDP() (netproxy.Conn, error) { +func (c *clientImpl) UDP(addr string) (netproxy.Conn, error) { if c.udpSM == nil { return nil, coreErrs.DialError{Message: "UDP not enabled"} } - return c.udpSM.NewUDP() + return c.udpSM.NewUDP(addr) } func (c *clientImpl) Close() error { diff --git a/protocol/hysteria2/client/reconnect.go b/protocol/hysteria2/client/reconnect.go index 00586ee..e74d915 100644 --- a/protocol/hysteria2/client/reconnect.go +++ b/protocol/hysteria2/client/reconnect.go @@ -100,9 +100,9 @@ func (rc *reconnectableClientImpl) TCP(addr string) (netproxy.Conn, error) { } } -func (rc *reconnectableClientImpl) UDP() (netproxy.Conn, error) { +func (rc *reconnectableClientImpl) UDP(addr string) (netproxy.Conn, error) { if c, err := rc.clientDo(func(client Client) (interface{}, error) { - return client.UDP() + return client.UDP(addr) }); err != nil { return nil, err } else { diff --git a/protocol/hysteria2/client/udp.go b/protocol/hysteria2/client/udp.go index 2a20197..4621938 100644 --- a/protocol/hysteria2/client/udp.go +++ b/protocol/hysteria2/client/udp.go @@ -35,60 +35,42 @@ type udpConn struct { muTimer sync.Mutex timer *time.Timer + target string } func (u *udpConn) Read(b []byte) (n int, err error) { + msg, _, err := u.ReadFrom(b) + return msg, err +} + +func (u *udpConn) Write(b []byte) (n int, err error) { + return u.WriteTo(b, u.target) +} + +func (u *udpConn) ReadFrom(p []byte) (n int, addr string, err error) { for { msg := <-u.ReceiveCh if msg == nil { // Closed - return 0, io.EOF + return 0, "", io.EOF } dfMsg := u.D.Feed(msg) if dfMsg == nil { // Incomplete message, wait for more 
continue } - n = copy(b, dfMsg.Data) - return n, nil - } -} - -func (u *udpConn) SetDeadline(t time.Time) error { - u.muTimer.Lock() - defer u.muTimer.Unlock() - dur := time.Until(t) - if u.timer != nil { - u.timer.Reset(dur) - } else { - u.timer = time.AfterFunc(dur, func() { - u.muTimer.Lock() - defer u.muTimer.Unlock() - u.Close() - u.timer = nil - }) + return copy(p, dfMsg.Data), dfMsg.Addr, nil } - return nil -} - -func (u *udpConn) SetReadDeadline(t time.Time) error { - // FIXME: Single direction. - return u.SetDeadline(t) -} - -func (u *udpConn) SetWriteDeadline(t time.Time) error { - // FIXME: Single direction. - return u.SetDeadline(t) } -func (u *udpConn) Write(b []byte) (n int, err error) { +func (u *udpConn) WriteTo(b []byte, addr string) (n int, err error) { // Try no frag first msg := &protocol.UDPMessage{ SessionID: u.ID, PacketID: 0, FragID: 0, FragCount: 1, - Addr: "", + Addr: addr, Data: b, } err = u.SendFunc(u.SendBuf, msg) @@ -114,6 +96,33 @@ func (u *udpConn) Close() error { return nil } +func (u *udpConn) SetDeadline(t time.Time) error { + u.muTimer.Lock() + defer u.muTimer.Unlock() + dur := time.Until(t) + if u.timer != nil { + u.timer.Reset(dur) + } else { + u.timer = time.AfterFunc(dur, func() { + u.muTimer.Lock() + defer u.muTimer.Unlock() + u.Close() + u.timer = nil + }) + } + return nil +} + +func (u *udpConn) SetReadDeadline(t time.Time) error { + // FIXME: Single direction. + return u.SetDeadline(t) +} + +func (u *udpConn) SetWriteDeadline(t time.Time) error { + // FIXME: Single direction. + return u.SetDeadline(t) +} + type udpSessionManager struct { io udpIO @@ -174,7 +183,7 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { } // NewUDP creates a new UDP session. 
-func (m *udpSessionManager) NewUDP() (netproxy.Conn, error) { +func (m *udpSessionManager) NewUDP(addr string) (netproxy.Conn, error) { m.mutex.Lock() defer m.mutex.Unlock() @@ -193,6 +202,7 @@ func (m *udpSessionManager) NewUDP() (netproxy.Conn, error) { SendFunc: m.io.SendMessage, muTimer: sync.Mutex{}, + target: addr, } conn.CloseFunc = func() { m.mutex.Lock() From 5a15d1290f5d98600048832f1d4fe3dc5835c970 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 03:13:24 +0800 Subject: [PATCH 07/13] feat: Add hysteria2 dialer --- protocol/hysteria2/dialer.go | 81 ++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 protocol/hysteria2/dialer.go diff --git a/protocol/hysteria2/dialer.go b/protocol/hysteria2/dialer.go new file mode 100644 index 0000000..38ff779 --- /dev/null +++ b/protocol/hysteria2/dialer.go @@ -0,0 +1,81 @@ +package hysteria2 + +import ( + "fmt" + "net" + + "github.com/daeuniverse/outbound/netproxy" + "github.com/daeuniverse/outbound/protocol" + "github.com/daeuniverse/outbound/protocol/hysteria2/client" +) + +func init() { + protocol.Register("hysteria2", NewDialer) +} + +type Dialer struct { + client client.Client + metadata protocol.Metadata +} + +func NewDialer(nextDialer netproxy.Dialer, header protocol.Header) (netproxy.Dialer, error) { + metadata := protocol.Metadata{ + IsClient: header.IsClient, + } + + serverAddr, err := net.ResolveUDPAddr("udp", header.ProxyAddress) + if err != nil { + return nil, err + } + + config := &client.Config{ + ServerAddr: serverAddr, + TLSConfig: client.TLSConfig{ + ServerName: header.TlsConfig.ServerName, + InsecureSkipVerify: header.TlsConfig.InsecureSkipVerify, + VerifyPeerCertificate: header.TlsConfig.VerifyPeerCertificate, + RootCAs: header.TlsConfig.RootCAs, + }, + Auth: header.User, + } + + client, err := client.NewReconnectableClient( + func() (*client.Config, error) { + return config, nil + }, + func(c client.Client, hi 
*client.HandshakeInfo, i int) { + // Do nothing + }, + false, + ) + if err != nil { + return nil, err + } + + return &Dialer{ + client: client, + metadata: metadata, + }, nil +} + +func (d *Dialer) Dial(network, address string) (netproxy.Conn, error) { + magicNetwork, err := netproxy.ParseMagicNetwork(network) + if err != nil { + return nil, err + } + + metadata, err := protocol.ParseMetadata(address) + if err != nil { + return nil, err + } + metadata.IsClient = d.metadata.IsClient + + switch magicNetwork.Network { + case "tcp": + return d.client.TCP(address) + case "udp": + return d.client.UDP(address) + default: + return nil, fmt.Errorf("unsupported network: %s", network) + } +} From 47c160ca080b01ddac66c636149db147c5762bc5 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 03:52:05 +0800 Subject: [PATCH 08/13] bug: Fix interface casting issue for udpConn --- protocol/hysteria2/client/udp.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/protocol/hysteria2/client/udp.go b/protocol/hysteria2/client/udp.go index 4621938..7c20935 100644 --- a/protocol/hysteria2/client/udp.go +++ b/protocol/hysteria2/client/udp.go @@ -4,6 +4,7 @@ import ( "errors" "io" "math/rand" + "net/netip" "sync" "time" @@ -47,19 +48,23 @@ func (u *udpConn) Write(b []byte) (n int, err error) { return u.WriteTo(b, u.target) } -func (u *udpConn) ReadFrom(p []byte) (n int, addr string, err error) { +func (u *udpConn) ReadFrom(p []byte) (n int, addr netip.AddrPort, err error) { for { msg := <-u.ReceiveCh if msg == nil { // Closed - return 0, "", io.EOF + return 0, netip.AddrPort{}, io.EOF } dfMsg := u.D.Feed(msg) if dfMsg == nil { // Incomplete message, wait for more continue } - return copy(p, dfMsg.Data), dfMsg.Addr, nil + netipAddr, err := netip.ParseAddrPort(dfMsg.Addr) + if err != nil { + return 0, netipAddr, err + } + return copy(p, dfMsg.Data), netipAddr, nil } } From 303c269bd8ab30427d49f422c128074e92bd27d4 Mon 
Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 03:52:15 +0800 Subject: [PATCH 09/13] feat: Add hysteria2 dialer test --- protocol/hysteria2/dialer_test.go | 93 +++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 protocol/hysteria2/dialer_test.go diff --git a/protocol/hysteria2/dialer_test.go b/protocol/hysteria2/dialer_test.go new file mode 100644 index 0000000..b958fab --- /dev/null +++ b/protocol/hysteria2/dialer_test.go @@ -0,0 +1,93 @@ +package hysteria2 + +import ( + "bytes" + "crypto/tls" + "fmt" + "net" + "net/http" + "strings" + "testing" + + "github.com/daeuniverse/outbound/netproxy" + "github.com/daeuniverse/outbound/protocol" + "github.com/daeuniverse/outbound/protocol/direct" + "golang.org/x/net/context" +) + +func TestTCP(t *testing.T) { + d, err := NewDialer(direct.SymmetricDirect, protocol.Header{ + ProxyAddress: "localhost:8443", + SNI: "", + TlsConfig: &tls.Config{InsecureSkipVerify: true, NextProtos: []string{"h3"}, MinVersion: tls.VersionTLS13, ServerName: "example.com"}, + User: "auth", + IsClient: true, + Flags: 0, + }) + if err != nil { + t.Fatal(err) + } + c := http.Client{ + Transport: &http.Transport{Dial: func(network string, addr string) (net.Conn, error) { + t.Log("target", addr) + c, err := d.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &netproxy.FakeNetConn{ + Conn: c, + LAddr: nil, + RAddr: nil, + }, nil + }}, + } + resp, err := c.Get("https://ipinfo.io") + if err != nil { + t.Fatal(err) + } + buf := new(bytes.Buffer) + buf.ReadFrom(resp.Body) + defer resp.Body.Close() + t.Log(buf.String()) +} + +func TestUDP(t *testing.T) { + d, err := NewDialer(direct.SymmetricDirect, protocol.Header{ + ProxyAddress: "localhost:8443", + SNI: "", + TlsConfig: &tls.Config{InsecureSkipVerify: true, NextProtos: []string{"h3"}, MinVersion: tls.VersionTLS13, ServerName: "example.com"}, + User: "auth", + IsClient: true, + Flags: 0, + }) + if err != nil { 
+ t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + resolver := net.Resolver{ + PreferGo: true, + StrictErrors: false, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + t.Log("target", address) + if !strings.HasPrefix(network, "udp") { + return nil, fmt.Errorf("unsupported network") + } + c, err := d.Dial("udp", address) + if err != nil { + return nil, err + } + return &netproxy.FakeNetPacketConn{ + PacketConn: c.(netproxy.PacketConn), + LAddr: nil, + RAddr: nil, + }, nil + }, + } + ips, err := resolver.LookupNetIP(context.TODO(), "ip", "www.baidu.com") + if err != nil { + t.Fatal(err) + } + t.Log(ips) +} From b5e31664ce151ae2207a3559002f51cf248149d6 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jun 2024 13:34:01 +0800 Subject: [PATCH 10/13] feat: Add Hysteria2 dialer with URL parser / exporter --- dialer/hysteria2/hysteria2.go | 134 ++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 dialer/hysteria2/hysteria2.go diff --git a/dialer/hysteria2/hysteria2.go b/dialer/hysteria2/hysteria2.go new file mode 100644 index 0000000..e99dc4d --- /dev/null +++ b/dialer/hysteria2/hysteria2.go @@ -0,0 +1,134 @@ +package hysteria2 + +import ( + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "errors" + "net" + "net/url" + "strconv" + "strings" + + "github.com/daeuniverse/outbound/dialer" + "github.com/daeuniverse/outbound/netproxy" + "github.com/daeuniverse/outbound/protocol" +) + +func init() { + dialer.FromLinkRegister("hysteria2", NewHysteria2) + dialer.FromLinkRegister("hy2", NewHysteria2) +} + +type Hysteria2 struct { + Name string + User string + Server string + Port int + Insecure bool + Sni string + PinSHA256 string +} + +func NewHysteria2(option *dialer.ExtraOption, nextDialer netproxy.Dialer, link string) (netproxy.Dialer, *dialer.Property, error) { + s, err := ParseHysteria2URL(link) + if err != nil { + return nil, nil, err + } + 
return s.Dialer(option, nextDialer) +} + +func (s *Hysteria2) Dialer(option *dialer.ExtraOption, nextDialer netproxy.Dialer) (netproxy.Dialer, *dialer.Property, error) { + d := nextDialer + proxyAddress := net.JoinHostPort(s.Server, strconv.Itoa(s.Port)) + header := protocol.Header{ + ProxyAddress: proxyAddress, + TlsConfig: &tls.Config{ + ServerName: s.Sni, + InsecureSkipVerify: s.Insecure || option.AllowInsecure, + }, + SNI: s.Sni, + User: s.User, + IsClient: true, + } + if s.PinSHA256 != "" { + nHash := normalizeCertHash(s.PinSHA256) + header.TlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + for _, cert := range rawCerts { + hash := sha256.Sum256(cert) + hashHex := hex.EncodeToString(hash[:]) + if hashHex == nHash { + return nil + } + } + // No match + return errors.New("no certificate matches the pinned hash") + } + } + var err error + if d, err = protocol.NewDialer("hysteria2", d, header); err != nil { + return nil, nil, err + } + return d, &dialer.Property{ + Name: s.Name, + Address: proxyAddress, + Protocol: "hysteria2", + Link: s.ExportToURL(), + }, nil +} + +func normalizeCertHash(hash string) string { + r := strings.ToLower(hash) + r = strings.ReplaceAll(r, ":", "") + r = strings.ReplaceAll(r, "-", "") + return r +} + +// ref: https://v2.hysteria.network/zh/docs/developers/URI-Scheme/ +func ParseHysteria2URL(link string) (*Hysteria2, error) { + // TODO: support salamander obfuscation + t, err := url.Parse(link) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(t.Port()) + if err != nil { + return nil, dialer.InvalidParameterErr + } + q := t.Query() + sni := q.Get("sni") + if sni == "" { + sni = t.Hostname() + } + return &Hysteria2{ + Name: t.Fragment, + User: t.User.String(), + Server: t.Hostname(), + Port: port, + Insecure: q.Get("insecure") == "1", + Sni: sni, + PinSHA256: q.Get("pinSHA256"), + }, nil +} + +func (s *Hysteria2) ExportToURL() string { + t := url.URL{ + Scheme: "hysteria2", + 
Host: net.JoinHostPort(s.Server, strconv.Itoa(s.Port)), + User: url.User(s.User), + Fragment: s.Name, + } + q := t.Query() + if s.Insecure { + q.Set("insecure", "1") + } + if s.Sni != "" { + q.Set("sni", s.Sni) + } + if s.PinSHA256 != "" { + q.Set("pinSHA256", s.PinSHA256) + } + t.RawQuery = q.Encode() + return t.String() +} From fdebd664594f0544ac7258572b87e43dcb9ff940 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Tue, 4 Jun 2024 23:44:35 +0800 Subject: [PATCH 11/13] feat: Enable lazy connection for Hysteria2 dialer --- protocol/hysteria2/dialer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol/hysteria2/dialer.go b/protocol/hysteria2/dialer.go index 38ff779..dbcfa34 100644 --- a/protocol/hysteria2/dialer.go +++ b/protocol/hysteria2/dialer.go @@ -46,7 +46,7 @@ func NewDialer(nextDialer netproxy.Dialer, header protocol.Header) (netproxy.Dia func(c client.Client, hi *client.HandshakeInfo, i int) { // Do nothing }, - false, + true, ) if err != nil { return nil, err From 8f1cb0a3520d4b9d456dc22be3bc7f61a8a33a73 Mon Sep 17 00:00:00 2001 From: mzz2017 <2017@duck.com> Date: Fri, 7 Jun 2024 00:10:33 +0800 Subject: [PATCH 12/13] chore: reuse code from tuic Signed-off-by: Mix <32300164+mnixry@users.noreply.github.com> --- protocol/hysteria2/client/client.go | 2 +- protocol/hysteria2/dialer.go | 6 - .../internal/congestion/bbr/bandwidth.go | 27 - .../congestion/bbr/bandwidth_sampler.go | 874 ---------------- .../internal/congestion/bbr/bbr_sender.go | 984 ------------------ .../internal/congestion/bbr/clock.go | 18 - .../bbr/packet_number_indexed_queue.go | 199 ---- .../internal/congestion/bbr/ringbuffer.go | 118 --- .../congestion/bbr/windowed_filter.go | 162 --- .../internal/congestion/brutal/brutal.go | 185 ---- .../internal/congestion/common/pacer.go | 79 -- .../hysteria2/internal/congestion/utils.go | 18 - protocol/tuic/congestion/brutal/brutal.go | 8 +- 13 files changed, 7 insertions(+), 2673 
deletions(-) delete mode 100644 protocol/hysteria2/internal/congestion/bbr/bandwidth.go delete mode 100644 protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go delete mode 100644 protocol/hysteria2/internal/congestion/bbr/bbr_sender.go delete mode 100644 protocol/hysteria2/internal/congestion/bbr/clock.go delete mode 100644 protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go delete mode 100644 protocol/hysteria2/internal/congestion/bbr/ringbuffer.go delete mode 100644 protocol/hysteria2/internal/congestion/bbr/windowed_filter.go delete mode 100644 protocol/hysteria2/internal/congestion/brutal/brutal.go delete mode 100644 protocol/hysteria2/internal/congestion/common/pacer.go delete mode 100644 protocol/hysteria2/internal/congestion/utils.go diff --git a/protocol/hysteria2/client/client.go b/protocol/hysteria2/client/client.go index fedae2d..c0fc16f 100644 --- a/protocol/hysteria2/client/client.go +++ b/protocol/hysteria2/client/client.go @@ -10,9 +10,9 @@ import ( "github.com/daeuniverse/outbound/netproxy" coreErrs "github.com/daeuniverse/outbound/protocol/hysteria2/errors" - "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion" "github.com/daeuniverse/outbound/protocol/hysteria2/internal/protocol" "github.com/daeuniverse/outbound/protocol/hysteria2/internal/utils" + "github.com/daeuniverse/outbound/protocol/tuic/congestion" "github.com/daeuniverse/quic-go" "github.com/daeuniverse/quic-go/http3" diff --git a/protocol/hysteria2/dialer.go b/protocol/hysteria2/dialer.go index dbcfa34..22f4d61 100644 --- a/protocol/hysteria2/dialer.go +++ b/protocol/hysteria2/dialer.go @@ -64,12 +64,6 @@ func (d *Dialer) Dial(network, address string) (netproxy.Conn, error) { return nil, err } - metadata, err := protocol.ParseMetadata(address) - if err != nil { - return nil, err - } - metadata.IsClient = d.metadata.IsClient - switch magicNetwork.Network { case "tcp": return d.client.TCP(address) diff --git 
a/protocol/hysteria2/internal/congestion/bbr/bandwidth.go b/protocol/hysteria2/internal/congestion/bbr/bandwidth.go deleted file mode 100644 index 23d870d..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/bandwidth.go +++ /dev/null @@ -1,27 +0,0 @@ -package bbr - -import ( - "math" - "time" - - "github.com/daeuniverse/quic-go/congestion" -) - -const ( - infBandwidth = Bandwidth(math.MaxUint64) -) - -// Bandwidth of a connection -type Bandwidth uint64 - -const ( - // BitsPerSecond is 1 bit per second - BitsPerSecond Bandwidth = 1 - // BytesPerSecond is 1 byte per second - BytesPerSecond = 8 * BitsPerSecond -) - -// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta -func BandwidthFromDelta(bytes congestion.ByteCount, delta time.Duration) Bandwidth { - return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond -} diff --git a/protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go b/protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go deleted file mode 100644 index 4b28d42..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/bandwidth_sampler.go +++ /dev/null @@ -1,874 +0,0 @@ -package bbr - -import ( - "math" - "time" - - "github.com/daeuniverse/quic-go/congestion" -) - -const ( - infRTT = time.Duration(math.MaxInt64) - defaultConnectionStateMapQueueSize = 256 - defaultCandidatesBufferSize = 256 -) - -type roundTripCount uint64 - -// SendTimeState is a subset of ConnectionStateOnSentPacket which is returned -// to the caller when the packet is acked or lost. -type sendTimeState struct { - // Whether other states in this object is valid. - isValid bool - // Whether the sender is app limited at the time the packet was sent. - // App limited bandwidth sample might be artificially low because the sender - // did not have enough data to send in order to saturate the link. - isAppLimited bool - // Total number of sent bytes at the time the packet was sent. 
- // Includes the packet itself. - totalBytesSent congestion.ByteCount - // Total number of acked bytes at the time the packet was sent. - totalBytesAcked congestion.ByteCount - // Total number of lost bytes at the time the packet was sent. - totalBytesLost congestion.ByteCount - // Total number of inflight bytes at the time the packet was sent. - // Includes the packet itself. - // It should be equal to |total_bytes_sent| minus the sum of - // |total_bytes_acked|, |total_bytes_lost| and total neutered bytes. - bytesInFlight congestion.ByteCount -} - -func newSendTimeState( - isAppLimited bool, - totalBytesSent congestion.ByteCount, - totalBytesAcked congestion.ByteCount, - totalBytesLost congestion.ByteCount, - bytesInFlight congestion.ByteCount, -) *sendTimeState { - return &sendTimeState{ - isValid: true, - isAppLimited: isAppLimited, - totalBytesSent: totalBytesSent, - totalBytesAcked: totalBytesAcked, - totalBytesLost: totalBytesLost, - bytesInFlight: bytesInFlight, - } -} - -type extraAckedEvent struct { - // The excess bytes acknowlwedged in the time delta for this event. - extraAcked congestion.ByteCount - - // The bytes acknowledged and time delta from the event. - bytesAcked congestion.ByteCount - timeDelta time.Duration - // The round trip of the event. - round roundTripCount -} - -func maxExtraAckedEventFunc(a, b extraAckedEvent) int { - if a.extraAcked > b.extraAcked { - return 1 - } else if a.extraAcked < b.extraAcked { - return -1 - } - return 0 -} - -// BandwidthSample -type bandwidthSample struct { - // The bandwidth at that particular sample. Zero if no valid bandwidth sample - // is available. - bandwidth Bandwidth - // The RTT measurement at this particular sample. Zero if no RTT sample is - // available. Does not correct for delayed ack time. - rtt time.Duration - // |send_rate| is computed from the current packet being acked('P') and an - // earlier packet that is acked before P was sent. 
- sendRate Bandwidth - // States captured when the packet was sent. - stateAtSend sendTimeState -} - -func newBandwidthSample() *bandwidthSample { - return &bandwidthSample{ - sendRate: infBandwidth, - } -} - -// MaxAckHeightTracker is part of the BandwidthSampler. It is called after every -// ack event to keep track the degree of ack aggregation(a.k.a "ack height"). -type maxAckHeightTracker struct { - // Tracks the maximum number of bytes acked faster than the estimated - // bandwidth. - maxAckHeightFilter *WindowedFilter[extraAckedEvent, roundTripCount] - // The time this aggregation started and the number of bytes acked during it. - aggregationEpochStartTime time.Time - aggregationEpochBytes congestion.ByteCount - // The last sent packet number before the current aggregation epoch started. - lastSentPacketNumberBeforeEpoch congestion.PacketNumber - // The number of ack aggregation epochs ever started, including the ongoing - // one. Stats only. - numAckAggregationEpochs uint64 - ackAggregationBandwidthThreshold float64 - startNewAggregationEpochAfterFullRound bool - reduceExtraAckedOnBandwidthIncrease bool -} - -func newMaxAckHeightTracker(windowLength roundTripCount) *maxAckHeightTracker { - return &maxAckHeightTracker{ - maxAckHeightFilter: NewWindowedFilter(windowLength, maxExtraAckedEventFunc), - lastSentPacketNumberBeforeEpoch: invalidPacketNumber, - ackAggregationBandwidthThreshold: 1.0, - } -} - -func (m *maxAckHeightTracker) Get() congestion.ByteCount { - return m.maxAckHeightFilter.GetBest().extraAcked -} - -func (m *maxAckHeightTracker) Update( - bandwidthEstimate Bandwidth, - isNewMaxBandwidth bool, - roundTripCount roundTripCount, - lastSentPacketNumber congestion.PacketNumber, - lastAckedPacketNumber congestion.PacketNumber, - ackTime time.Time, - bytesAcked congestion.ByteCount, -) congestion.ByteCount { - forceNewEpoch := false - - if m.reduceExtraAckedOnBandwidthIncrease && isNewMaxBandwidth { - // Save and clear existing entries. 
- best := m.maxAckHeightFilter.GetBest() - secondBest := m.maxAckHeightFilter.GetSecondBest() - thirdBest := m.maxAckHeightFilter.GetThirdBest() - m.maxAckHeightFilter.Clear() - - // Reinsert the heights into the filter after recalculating. - expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, best.timeDelta) - if expectedBytesAcked < best.bytesAcked { - best.extraAcked = best.bytesAcked - expectedBytesAcked - m.maxAckHeightFilter.Update(best, best.round) - } - expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, secondBest.timeDelta) - if expectedBytesAcked < secondBest.bytesAcked { - secondBest.extraAcked = secondBest.bytesAcked - expectedBytesAcked - m.maxAckHeightFilter.Update(secondBest, secondBest.round) - } - expectedBytesAcked = bytesFromBandwidthAndTimeDelta(bandwidthEstimate, thirdBest.timeDelta) - if expectedBytesAcked < thirdBest.bytesAcked { - thirdBest.extraAcked = thirdBest.bytesAcked - expectedBytesAcked - m.maxAckHeightFilter.Update(thirdBest, thirdBest.round) - } - } - - // If any packet sent after the start of the epoch has been acked, start a new - // epoch. - if m.startNewAggregationEpochAfterFullRound && - m.lastSentPacketNumberBeforeEpoch != invalidPacketNumber && - lastAckedPacketNumber != invalidPacketNumber && - lastAckedPacketNumber > m.lastSentPacketNumberBeforeEpoch { - forceNewEpoch = true - } - if m.aggregationEpochStartTime.IsZero() || forceNewEpoch { - m.aggregationEpochBytes = bytesAcked - m.aggregationEpochStartTime = ackTime - m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber - m.numAckAggregationEpochs++ - return 0 - } - - // Compute how many bytes are expected to be delivered, assuming max bandwidth - // is correct. 
- aggregationDelta := ackTime.Sub(m.aggregationEpochStartTime) - expectedBytesAcked := bytesFromBandwidthAndTimeDelta(bandwidthEstimate, aggregationDelta) - // Reset the current aggregation epoch as soon as the ack arrival rate is less - // than or equal to the max bandwidth. - if m.aggregationEpochBytes <= congestion.ByteCount(m.ackAggregationBandwidthThreshold*float64(expectedBytesAcked)) { - // Reset to start measuring a new aggregation epoch. - m.aggregationEpochBytes = bytesAcked - m.aggregationEpochStartTime = ackTime - m.lastSentPacketNumberBeforeEpoch = lastSentPacketNumber - m.numAckAggregationEpochs++ - return 0 - } - - m.aggregationEpochBytes += bytesAcked - - // Compute how many extra bytes were delivered vs max bandwidth. - extraBytesAcked := m.aggregationEpochBytes - expectedBytesAcked - newEvent := extraAckedEvent{ - extraAcked: expectedBytesAcked, - bytesAcked: m.aggregationEpochBytes, - timeDelta: aggregationDelta, - } - m.maxAckHeightFilter.Update(newEvent, roundTripCount) - return extraBytesAcked -} - -func (m *maxAckHeightTracker) SetFilterWindowLength(length roundTripCount) { - m.maxAckHeightFilter.SetWindowLength(length) -} - -func (m *maxAckHeightTracker) Reset(newHeight congestion.ByteCount, newTime roundTripCount) { - newEvent := extraAckedEvent{ - extraAcked: newHeight, - round: newTime, - } - m.maxAckHeightFilter.Reset(newEvent, newTime) -} - -func (m *maxAckHeightTracker) SetAckAggregationBandwidthThreshold(threshold float64) { - m.ackAggregationBandwidthThreshold = threshold -} - -func (m *maxAckHeightTracker) SetStartNewAggregationEpochAfterFullRound(value bool) { - m.startNewAggregationEpochAfterFullRound = value -} - -func (m *maxAckHeightTracker) SetReduceExtraAckedOnBandwidthIncrease(value bool) { - m.reduceExtraAckedOnBandwidthIncrease = value -} - -func (m *maxAckHeightTracker) AckAggregationBandwidthThreshold() float64 { - return m.ackAggregationBandwidthThreshold -} - -func (m *maxAckHeightTracker) NumAckAggregationEpochs() 
uint64 { - return m.numAckAggregationEpochs -} - -// AckPoint represents a point on the ack line. -type ackPoint struct { - ackTime time.Time - totalBytesAcked congestion.ByteCount -} - -// RecentAckPoints maintains the most recent 2 ack points at distinct times. -type recentAckPoints struct { - ackPoints [2]ackPoint -} - -func (r *recentAckPoints) Update(ackTime time.Time, totalBytesAcked congestion.ByteCount) { - if ackTime.Before(r.ackPoints[1].ackTime) { - r.ackPoints[1].ackTime = ackTime - } else if ackTime.After(r.ackPoints[1].ackTime) { - r.ackPoints[0] = r.ackPoints[1] - r.ackPoints[1].ackTime = ackTime - } - - r.ackPoints[1].totalBytesAcked = totalBytesAcked -} - -func (r *recentAckPoints) Clear() { - r.ackPoints[0] = ackPoint{} - r.ackPoints[1] = ackPoint{} -} - -func (r *recentAckPoints) MostRecentPoint() *ackPoint { - return &r.ackPoints[1] -} - -func (r *recentAckPoints) LessRecentPoint() *ackPoint { - if r.ackPoints[0].totalBytesAcked != 0 { - return &r.ackPoints[0] - } - - return &r.ackPoints[1] -} - -// ConnectionStateOnSentPacket represents the information about a sent packet -// and the state of the connection at the moment the packet was sent, -// specifically the information about the most recently acknowledged packet at -// that moment. -type connectionStateOnSentPacket struct { - // Time at which the packet is sent. - sentTime time.Time - // Size of the packet. - size congestion.ByteCount - // The value of |totalBytesSentAtLastAckedPacket| at the time the - // packet was sent. - totalBytesSentAtLastAckedPacket congestion.ByteCount - // The value of |lastAckedPacketSentTime| at the time the packet was - // sent. - lastAckedPacketSentTime time.Time - // The value of |lastAckedPacketAckTime| at the time the packet was - // sent. - lastAckedPacketAckTime time.Time - // Send time states that are returned to the congestion controller when the - // packet is acked or lost. - sendTimeState sendTimeState -} - -// Snapshot constructor. 
Records the current state of the bandwidth -// sampler. -// |bytes_in_flight| is the bytes in flight right after the packet is sent. -func newConnectionStateOnSentPacket( - sentTime time.Time, - size congestion.ByteCount, - bytesInFlight congestion.ByteCount, - sampler *bandwidthSampler, -) *connectionStateOnSentPacket { - return &connectionStateOnSentPacket{ - sentTime: sentTime, - size: size, - totalBytesSentAtLastAckedPacket: sampler.totalBytesSentAtLastAckedPacket, - lastAckedPacketSentTime: sampler.lastAckedPacketSentTime, - lastAckedPacketAckTime: sampler.lastAckedPacketAckTime, - sendTimeState: *newSendTimeState( - sampler.isAppLimited, - sampler.totalBytesSent, - sampler.totalBytesAcked, - sampler.totalBytesLost, - bytesInFlight, - ), - } -} - -// BandwidthSampler keeps track of sent and acknowledged packets and outputs a -// bandwidth sample for every packet acknowledged. The samples are taken for -// individual packets, and are not filtered; the consumer has to filter the -// bandwidth samples itself. In certain cases, the sampler will locally severely -// underestimate the bandwidth, hence a maximum filter with a size of at least -// one RTT is recommended. -// -// This class bases its samples on the slope of two curves: the number of bytes -// sent over time, and the number of bytes acknowledged as received over time. -// It produces a sample of both slopes for every packet that gets acknowledged, -// based on a slope between two points on each of the corresponding curves. Note -// that due to the packet loss, the number of bytes on each curve might get -// further and further away from each other, meaning that it is not feasible to -// compare byte values coming from different curves with each other. -// -// The obvious points for measuring slope sample are the ones corresponding to -// the packet that was just acknowledged. 
Let us denote them as S_1 (point at -// which the current packet was sent) and A_1 (point at which the current packet -// was acknowledged). However, taking a slope requires two points on each line, -// so estimating bandwidth requires picking a packet in the past with respect to -// which the slope is measured. -// -// For that purpose, BandwidthSampler always keeps track of the most recently -// acknowledged packet, and records it together with every outgoing packet. -// When a packet gets acknowledged (A_1), it has not only information about when -// it itself was sent (S_1), but also the information about the latest -// acknowledged packet right before it was sent (S_0 and A_0). -// -// Based on that data, send and ack rate are estimated as: -// -// send_rate = (bytes(S_1) - bytes(S_0)) / (time(S_1) - time(S_0)) -// ack_rate = (bytes(A_1) - bytes(A_0)) / (time(A_1) - time(A_0)) -// -// Here, the ack rate is intuitively the rate we want to treat as bandwidth. -// However, in certain cases (e.g. ack compression) the ack rate at a point may -// end up higher than the rate at which the data was originally sent, which is -// not indicative of the real bandwidth. Hence, we use the send rate as an upper -// bound, and the sample value is -// -// rate_sample = min(send_rate, ack_rate) -// -// An important edge case handled by the sampler is tracking the app-limited -// samples. There are multiple meaning of "app-limited" used interchangeably, -// hence it is important to understand and to be able to distinguish between -// them. -// -// Meaning 1: connection state. The connection is said to be app-limited when -// there is no outstanding data to send. This means that certain bandwidth -// samples in the future would not be an accurate indication of the link -// capacity, and it is important to inform consumer about that. Whenever -// connection becomes app-limited, the sampler is notified via OnAppLimited() -// method. 
-// -// Meaning 2: a phase in the bandwidth sampler. As soon as the bandwidth -// sampler becomes notified about the connection being app-limited, it enters -// app-limited phase. In that phase, all *sent* packets are marked as -// app-limited. Note that the connection itself does not have to be -// app-limited during the app-limited phase, and in fact it will not be -// (otherwise how would it send packets?). The boolean flag below indicates -// whether the sampler is in that phase. -// -// Meaning 3: a flag on the sent packet and on the sample. If a sent packet is -// sent during the app-limited phase, the resulting sample related to the -// packet will be marked as app-limited. -// -// With the terminology issue out of the way, let us consider the question of -// what kind of situation it addresses. -// -// Consider a scenario where we first send packets 1 to 20 at a regular -// bandwidth, and then immediately run out of data. After a few seconds, we send -// packets 21 to 60, and only receive ack for 21 between sending packets 40 and -// 41. In this case, when we sample bandwidth for packets 21 to 40, the S_0/A_0 -// we use to compute the slope is going to be packet 20, a few seconds apart -// from the current packet, hence the resulting estimate would be extremely low -// and not indicative of anything. Only at packet 41 the S_0/A_0 will become 21, -// meaning that the bandwidth sample would exclude the quiescence. -// -// Based on the analysis of that scenario, we implement the following rule: once -// OnAppLimited() is called, all sent packets will produce app-limited samples -// up until an ack for a packet that was sent after OnAppLimited() was called. -// Note that while the scenario above is not the only scenario when the -// connection is app-limited, the approach works in other cases too. - -type congestionEventSample struct { - // The maximum bandwidth sample from all acked packets. - // QuicBandwidth::Zero() if no samples are available. 
- sampleMaxBandwidth Bandwidth - // Whether |sample_max_bandwidth| is from a app-limited sample. - sampleIsAppLimited bool - // The minimum rtt sample from all acked packets. - // QuicTime::Delta::Infinite() if no samples are available. - sampleRtt time.Duration - // For each packet p in acked packets, this is the max value of INFLIGHT(p), - // where INFLIGHT(p) is the number of bytes acked while p is inflight. - sampleMaxInflight congestion.ByteCount - // The send state of the largest packet in acked_packets, unless it is - // empty. If acked_packets is empty, it's the send state of the largest - // packet in lost_packets. - lastPacketSendState sendTimeState - // The number of extra bytes acked from this ack event, compared to what is - // expected from the flow's bandwidth. Larger value means more ack - // aggregation. - extraAcked congestion.ByteCount -} - -func newCongestionEventSample() *congestionEventSample { - return &congestionEventSample{ - sampleRtt: infRTT, - } -} - -type bandwidthSampler struct { - // The total number of congestion controlled bytes sent during the connection. - totalBytesSent congestion.ByteCount - - // The total number of congestion controlled bytes which were acknowledged. - totalBytesAcked congestion.ByteCount - - // The total number of congestion controlled bytes which were lost. - totalBytesLost congestion.ByteCount - - // The total number of congestion controlled bytes which have been neutered. - totalBytesNeutered congestion.ByteCount - - // The value of |total_bytes_sent_| at the time the last acknowledged packet - // was sent. Valid only when |last_acked_packet_sent_time_| is valid. - totalBytesSentAtLastAckedPacket congestion.ByteCount - - // The time at which the last acknowledged packet was sent. Set to - // QuicTime::Zero() if no valid timestamp is available. - lastAckedPacketSentTime time.Time - - // The time at which the most recent packet was acknowledged. 
- lastAckedPacketAckTime time.Time - - // The most recently sent packet. - lastSentPacket congestion.PacketNumber - - // The most recently acked packet. - lastAckedPacket congestion.PacketNumber - - // Indicates whether the bandwidth sampler is currently in an app-limited - // phase. - isAppLimited bool - - // The packet that will be acknowledged after this one will cause the sampler - // to exit the app-limited phase. - endOfAppLimitedPhase congestion.PacketNumber - - // Record of the connection state at the point where each packet in flight was - // sent, indexed by the packet number. - connectionStateMap *packetNumberIndexedQueue[connectionStateOnSentPacket] - - recentAckPoints recentAckPoints - a0Candidates RingBuffer[ackPoint] - - // Maximum number of tracked packets. - maxTrackedPackets congestion.ByteCount - - maxAckHeightTracker *maxAckHeightTracker - totalBytesAckedAfterLastAckEvent congestion.ByteCount - - // True if connection option 'BSAO' is set. - overestimateAvoidance bool - - // True if connection option 'BBRB' is set. 
- limitMaxAckHeightTrackerBySendRate bool -} - -func newBandwidthSampler(maxAckHeightTrackerWindowLength roundTripCount) *bandwidthSampler { - b := &bandwidthSampler{ - maxAckHeightTracker: newMaxAckHeightTracker(maxAckHeightTrackerWindowLength), - connectionStateMap: newPacketNumberIndexedQueue[connectionStateOnSentPacket](defaultConnectionStateMapQueueSize), - lastSentPacket: invalidPacketNumber, - lastAckedPacket: invalidPacketNumber, - endOfAppLimitedPhase: invalidPacketNumber, - } - - b.a0Candidates.Init(defaultCandidatesBufferSize) - - return b -} - -func (b *bandwidthSampler) MaxAckHeight() congestion.ByteCount { - return b.maxAckHeightTracker.Get() -} - -func (b *bandwidthSampler) NumAckAggregationEpochs() uint64 { - return b.maxAckHeightTracker.NumAckAggregationEpochs() -} - -func (b *bandwidthSampler) SetMaxAckHeightTrackerWindowLength(length roundTripCount) { - b.maxAckHeightTracker.SetFilterWindowLength(length) -} - -func (b *bandwidthSampler) ResetMaxAckHeightTracker(newHeight congestion.ByteCount, newTime roundTripCount) { - b.maxAckHeightTracker.Reset(newHeight, newTime) -} - -func (b *bandwidthSampler) SetStartNewAggregationEpochAfterFullRound(value bool) { - b.maxAckHeightTracker.SetStartNewAggregationEpochAfterFullRound(value) -} - -func (b *bandwidthSampler) SetLimitMaxAckHeightTrackerBySendRate(value bool) { - b.limitMaxAckHeightTrackerBySendRate = value -} - -func (b *bandwidthSampler) SetReduceExtraAckedOnBandwidthIncrease(value bool) { - b.maxAckHeightTracker.SetReduceExtraAckedOnBandwidthIncrease(value) -} - -func (b *bandwidthSampler) EnableOverestimateAvoidance() { - if b.overestimateAvoidance { - return - } - - b.overestimateAvoidance = true - b.maxAckHeightTracker.SetAckAggregationBandwidthThreshold(2.0) -} - -func (b *bandwidthSampler) IsOverestimateAvoidanceEnabled() bool { - return b.overestimateAvoidance -} - -func (b *bandwidthSampler) OnPacketSent( - sentTime time.Time, - packetNumber congestion.PacketNumber, - bytes 
congestion.ByteCount, - bytesInFlight congestion.ByteCount, - isRetransmittable bool, -) { - b.lastSentPacket = packetNumber - - if !isRetransmittable { - return - } - - b.totalBytesSent += bytes - - // If there are no packets in flight, the time at which the new transmission - // opens can be treated as the A_0 point for the purpose of bandwidth - // sampling. This underestimates bandwidth to some extent, and produces some - // artificially low samples for most packets in flight, but it provides with - // samples at important points where we would not have them otherwise, most - // importantly at the beginning of the connection. - if bytesInFlight == 0 { - b.lastAckedPacketAckTime = sentTime - if b.overestimateAvoidance { - b.recentAckPoints.Clear() - b.recentAckPoints.Update(sentTime, b.totalBytesAcked) - b.a0Candidates.Clear() - b.a0Candidates.PushBack(*b.recentAckPoints.MostRecentPoint()) - } - b.totalBytesSentAtLastAckedPacket = b.totalBytesSent - - // In this situation ack compression is not a concern, set send rate to - // effectively infinite. - b.lastAckedPacketSentTime = sentTime - } - - b.connectionStateMap.Emplace(packetNumber, newConnectionStateOnSentPacket( - sentTime, - bytes, - bytesInFlight+bytes, - b, - )) -} - -func (b *bandwidthSampler) OnCongestionEvent( - ackTime time.Time, - ackedPackets []congestion.AckedPacketInfo, - lostPackets []congestion.LostPacketInfo, - maxBandwidth Bandwidth, - estBandwidthUpperBound Bandwidth, - roundTripCount roundTripCount, -) congestionEventSample { - eventSample := newCongestionEventSample() - - var lastLostPacketSendState sendTimeState - - for _, p := range lostPackets { - sendState := b.OnPacketLost(p.PacketNumber, p.BytesLost) - if sendState.isValid { - lastLostPacketSendState = sendState - } - } - - if len(ackedPackets) == 0 { - // Only populate send state for a loss-only event. 
- eventSample.lastPacketSendState = lastLostPacketSendState - return *eventSample - } - - var lastAckedPacketSendState sendTimeState - var maxSendRate Bandwidth - - for _, p := range ackedPackets { - sample := b.onPacketAcknowledged(ackTime, p.PacketNumber) - if !sample.stateAtSend.isValid { - continue - } - - lastAckedPacketSendState = sample.stateAtSend - - if sample.rtt != 0 { - eventSample.sampleRtt = min(eventSample.sampleRtt, sample.rtt) - } - if sample.bandwidth > eventSample.sampleMaxBandwidth { - eventSample.sampleMaxBandwidth = sample.bandwidth - eventSample.sampleIsAppLimited = sample.stateAtSend.isAppLimited - } - if sample.sendRate != infBandwidth { - maxSendRate = max(maxSendRate, sample.sendRate) - } - inflightSample := b.totalBytesAcked - lastAckedPacketSendState.totalBytesAcked - if inflightSample > eventSample.sampleMaxInflight { - eventSample.sampleMaxInflight = inflightSample - } - } - - if !lastLostPacketSendState.isValid { - eventSample.lastPacketSendState = lastAckedPacketSendState - } else if !lastAckedPacketSendState.isValid { - eventSample.lastPacketSendState = lastLostPacketSendState - } else { - // If two packets are inflight and an alarm is armed to lose a packet and it - // wakes up late, then the first of two in flight packets could have been - // acknowledged before the wakeup, which re-evaluates loss detection, and - // could declare the later of the two lost. 
- if lostPackets[len(lostPackets)-1].PacketNumber > ackedPackets[len(ackedPackets)-1].PacketNumber { - eventSample.lastPacketSendState = lastLostPacketSendState - } else { - eventSample.lastPacketSendState = lastAckedPacketSendState - } - } - - isNewMaxBandwidth := eventSample.sampleMaxBandwidth > maxBandwidth - maxBandwidth = max(maxBandwidth, eventSample.sampleMaxBandwidth) - if b.limitMaxAckHeightTrackerBySendRate { - maxBandwidth = max(maxBandwidth, maxSendRate) - } - - eventSample.extraAcked = b.onAckEventEnd(min(estBandwidthUpperBound, maxBandwidth), isNewMaxBandwidth, roundTripCount) - - return *eventSample -} - -func (b *bandwidthSampler) OnPacketLost(packetNumber congestion.PacketNumber, bytesLost congestion.ByteCount) (s sendTimeState) { - b.totalBytesLost += bytesLost - if sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber); sentPacketPointer != nil { - sentPacketToSendTimeState(sentPacketPointer, &s) - } - return s -} - -func (b *bandwidthSampler) OnPacketNeutered(packetNumber congestion.PacketNumber) { - b.connectionStateMap.Remove(packetNumber, func(sentPacket connectionStateOnSentPacket) { - b.totalBytesNeutered += sentPacket.size - }) -} - -func (b *bandwidthSampler) OnAppLimited() { - b.isAppLimited = true - b.endOfAppLimitedPhase = b.lastSentPacket -} - -func (b *bandwidthSampler) RemoveObsoletePackets(leastUnacked congestion.PacketNumber) { - // A packet can become obsolete when it is removed from QuicUnackedPacketMap's - // view of inflight before it is acked or marked as lost. For example, when - // QuicSentPacketManager::RetransmitCryptoPackets retransmits a crypto packet, - // the packet is removed from QuicUnackedPacketMap's inflight, but is not - // marked as acked or lost in the BandwidthSampler. 
- b.connectionStateMap.RemoveUpTo(leastUnacked) -} - -func (b *bandwidthSampler) TotalBytesSent() congestion.ByteCount { - return b.totalBytesSent -} - -func (b *bandwidthSampler) TotalBytesLost() congestion.ByteCount { - return b.totalBytesLost -} - -func (b *bandwidthSampler) TotalBytesAcked() congestion.ByteCount { - return b.totalBytesAcked -} - -func (b *bandwidthSampler) TotalBytesNeutered() congestion.ByteCount { - return b.totalBytesNeutered -} - -func (b *bandwidthSampler) IsAppLimited() bool { - return b.isAppLimited -} - -func (b *bandwidthSampler) EndOfAppLimitedPhase() congestion.PacketNumber { - return b.endOfAppLimitedPhase -} - -func (b *bandwidthSampler) max_ack_height() congestion.ByteCount { - return b.maxAckHeightTracker.Get() -} - -func (b *bandwidthSampler) chooseA0Point(totalBytesAcked congestion.ByteCount, a0 *ackPoint) bool { - if b.a0Candidates.Empty() { - return false - } - - if b.a0Candidates.Len() == 1 { - *a0 = *b.a0Candidates.Front() - return true - } - - for i := 1; i < b.a0Candidates.Len(); i++ { - if b.a0Candidates.Offset(i).totalBytesAcked > totalBytesAcked { - *a0 = *b.a0Candidates.Offset(i - 1) - if i > 1 { - for j := 0; j < i-1; j++ { - b.a0Candidates.PopFront() - } - } - return true - } - } - - *a0 = *b.a0Candidates.Back() - for k := 0; k < b.a0Candidates.Len()-1; k++ { - b.a0Candidates.PopFront() - } - return true -} - -func (b *bandwidthSampler) onPacketAcknowledged(ackTime time.Time, packetNumber congestion.PacketNumber) bandwidthSample { - sample := newBandwidthSample() - b.lastAckedPacket = packetNumber - sentPacketPointer := b.connectionStateMap.GetEntry(packetNumber) - if sentPacketPointer == nil { - return *sample - } - - // OnPacketAcknowledgedInner - b.totalBytesAcked += sentPacketPointer.size - b.totalBytesSentAtLastAckedPacket = sentPacketPointer.sendTimeState.totalBytesSent - b.lastAckedPacketSentTime = sentPacketPointer.sentTime - b.lastAckedPacketAckTime = ackTime - if b.overestimateAvoidance { - 
b.recentAckPoints.Update(ackTime, b.totalBytesAcked) - } - - if b.isAppLimited { - // Exit app-limited phase in two cases: - // (1) end_of_app_limited_phase_ is not initialized, i.e., so far all - // packets are sent while there are buffered packets or pending data. - // (2) The current acked packet is after the sent packet marked as the end - // of the app limit phase. - if b.endOfAppLimitedPhase == invalidPacketNumber || - packetNumber > b.endOfAppLimitedPhase { - b.isAppLimited = false - } - } - - // There might have been no packets acknowledged at the moment when the - // current packet was sent. In that case, there is no bandwidth sample to - // make. - if sentPacketPointer.lastAckedPacketSentTime.IsZero() { - return *sample - } - - // Infinite rate indicates that the sampler is supposed to discard the - // current send rate sample and use only the ack rate. - sendRate := infBandwidth - if sentPacketPointer.sentTime.After(sentPacketPointer.lastAckedPacketSentTime) { - sendRate = BandwidthFromDelta( - sentPacketPointer.sendTimeState.totalBytesSent-sentPacketPointer.totalBytesSentAtLastAckedPacket, - sentPacketPointer.sentTime.Sub(sentPacketPointer.lastAckedPacketSentTime)) - } - - var a0 ackPoint - if b.overestimateAvoidance && b.chooseA0Point(sentPacketPointer.sendTimeState.totalBytesAcked, &a0) { - } else { - a0.ackTime = sentPacketPointer.lastAckedPacketAckTime - a0.totalBytesAcked = sentPacketPointer.sendTimeState.totalBytesAcked - } - - // During the slope calculation, ensure that ack time of the current packet is - // always larger than the time of the previous packet, otherwise division by - // zero or integer underflow can occur. - if ackTime.Sub(a0.ackTime) <= 0 { - return *sample - } - - ackRate := BandwidthFromDelta(b.totalBytesAcked-a0.totalBytesAcked, ackTime.Sub(a0.ackTime)) - - sample.bandwidth = min(sendRate, ackRate) - // Note: this sample does not account for delayed acknowledgement time. 
This - // means that the RTT measurements here can be artificially high, especially - // on low bandwidth connections. - sample.rtt = ackTime.Sub(sentPacketPointer.sentTime) - sample.sendRate = sendRate - sentPacketToSendTimeState(sentPacketPointer, &sample.stateAtSend) - - return *sample -} - -func (b *bandwidthSampler) onAckEventEnd( - bandwidthEstimate Bandwidth, - isNewMaxBandwidth bool, - roundTripCount roundTripCount, -) congestion.ByteCount { - newlyAckedBytes := b.totalBytesAcked - b.totalBytesAckedAfterLastAckEvent - if newlyAckedBytes == 0 { - return 0 - } - b.totalBytesAckedAfterLastAckEvent = b.totalBytesAcked - extraAcked := b.maxAckHeightTracker.Update( - bandwidthEstimate, - isNewMaxBandwidth, - roundTripCount, - b.lastSentPacket, - b.lastAckedPacket, - b.lastAckedPacketAckTime, - newlyAckedBytes) - // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack - // aggregation epoch, save LessRecentPoint, which is the last ack point of the - // previous epoch, as a A0 candidate. 
- if b.overestimateAvoidance && extraAcked == 0 { - b.a0Candidates.PushBack(*b.recentAckPoints.LessRecentPoint()) - } - return extraAcked -} - -func sentPacketToSendTimeState(sentPacket *connectionStateOnSentPacket, sendTimeState *sendTimeState) { - *sendTimeState = sentPacket.sendTimeState - sendTimeState.isValid = true -} - -// BytesFromBandwidthAndTimeDelta calculates the bytes -// from a bandwidth(bits per second) and a time delta -func bytesFromBandwidthAndTimeDelta(bandwidth Bandwidth, delta time.Duration) congestion.ByteCount { - return (congestion.ByteCount(bandwidth) * congestion.ByteCount(delta)) / - (congestion.ByteCount(time.Second) * 8) -} - -func timeDeltaFromBytesAndBandwidth(bytes congestion.ByteCount, bandwidth Bandwidth) time.Duration { - return time.Duration(bytes*8) * time.Second / time.Duration(bandwidth) -} diff --git a/protocol/hysteria2/internal/congestion/bbr/bbr_sender.go b/protocol/hysteria2/internal/congestion/bbr/bbr_sender.go deleted file mode 100644 index 63f5528..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/bbr_sender.go +++ /dev/null @@ -1,984 +0,0 @@ -package bbr - -import ( - "fmt" - "math/rand" - "net" - "os" - "strconv" - "time" - - "github.com/daeuniverse/quic-go/congestion" - - "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/common" -) - -// BbrSender implements BBR congestion control algorithm. BBR aims to estimate -// the current available Bottleneck Bandwidth and RTT (hence the name), and -// regulates the pacing rate and the size of the congestion window based on -// those signals. -// -// BBR relies on pacing in order to function properly. Do not use BBR when -// pacing is disabled. -// - -const ( - minBps = 65536 // 64 kbps - - invalidPacketNumber = -1 - initialCongestionWindowPackets = 32 - - // Constants based on TCP defaults. - // The minimum CWND to ensure delayed acks don't reduce bandwidth measurements. - // Does not inflate the pacing rate. 
- defaultMinimumCongestionWindow = 4 * congestion.ByteCount(congestion.InitialPacketSizeIPv4) - - // The gain used for the STARTUP, equal to 2/ln(2). - defaultHighGain = 2.885 - // The newly derived gain for STARTUP, equal to 4 * ln(2) - derivedHighGain = 2.773 - // The newly derived CWND gain for STARTUP, 2. - derivedHighCWNDGain = 2.0 - - debugEnv = "HYSTERIA_BBR_DEBUG" -) - -// The cycle of gains used during the PROBE_BW stage. -var pacingGain = [...]float64{1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0} - -const ( - // The length of the gain cycle. - gainCycleLength = len(pacingGain) - // The size of the bandwidth filter window, in round-trips. - bandwidthWindowSize = gainCycleLength + 2 - - // The time after which the current min_rtt value expires. - minRttExpiry = 10 * time.Second - // The minimum time the connection can spend in PROBE_RTT mode. - probeRttTime = 200 * time.Millisecond - // If the bandwidth does not increase by the factor of |kStartupGrowthTarget| - // within |kRoundTripsWithoutGrowthBeforeExitingStartup| rounds, the connection - // will exit the STARTUP mode. - startupGrowthTarget = 1.25 - roundTripsWithoutGrowthBeforeExitingStartup = int64(3) - - // Flag. - defaultStartupFullLossCount = 8 - quicBbr2DefaultLossThreshold = 0.02 - maxBbrBurstPackets = 10 -) - -type bbrMode int - -const ( - // Startup phase of the connection. - bbrModeStartup = iota - // After achieving the highest possible bandwidth during the startup, lower - // the pacing rate in order to drain the queue. - bbrModeDrain - // Cruising mode. - bbrModeProbeBw - // Temporarily slow down sending in order to empty the buffer and measure - // the real minimum RTT. - bbrModeProbeRtt -) - -// Indicates how the congestion control limits the amount of bytes in flight. -type bbrRecoveryState int - -const ( - // Do not limit. - bbrRecoveryStateNotInRecovery = iota - // Allow an extra outstanding byte for each byte acknowledged. 
- bbrRecoveryStateConservation - // Allow two extra outstanding bytes for each byte acknowledged (slow - // start). - bbrRecoveryStateGrowth -) - -type bbrSender struct { - rttStats congestion.RTTStatsProvider - clock Clock - pacer *common.Pacer - - mode bbrMode - - // Bandwidth sampler provides BBR with the bandwidth measurements at - // individual points. - sampler *bandwidthSampler - - // The number of the round trips that have occurred during the connection. - roundTripCount roundTripCount - - // The packet number of the most recently sent packet. - lastSentPacket congestion.PacketNumber - // Acknowledgement of any packet after |current_round_trip_end_| will cause - // the round trip counter to advance. - currentRoundTripEnd congestion.PacketNumber - - // Number of congestion events with some losses, in the current round. - numLossEventsInRound uint64 - - // Number of total bytes lost in the current round. - bytesLostInRound congestion.ByteCount - - // The filter that tracks the maximum bandwidth over the multiple recent - // round-trips. - maxBandwidth *WindowedFilter[Bandwidth, roundTripCount] - - // Minimum RTT estimate. Automatically expires within 10 seconds (and - // triggers PROBE_RTT mode) if no new value is sampled during that period. - minRtt time.Duration - // The time at which the current value of |min_rtt_| was assigned. - minRttTimestamp time.Time - - // The maximum allowed number of bytes in flight. - congestionWindow congestion.ByteCount - - // The initial value of the |congestion_window_|. - initialCongestionWindow congestion.ByteCount - - // The largest value the |congestion_window_| can achieve. - maxCongestionWindow congestion.ByteCount - - // The smallest value the |congestion_window_| can achieve. - minCongestionWindow congestion.ByteCount - - // The pacing gain applied during the STARTUP phase. - highGain float64 - - // The CWND gain applied during the STARTUP phase. 
- highCwndGain float64 - - // The pacing gain applied during the DRAIN phase. - drainGain float64 - - // The current pacing rate of the connection. - pacingRate Bandwidth - - // The gain currently applied to the pacing rate. - pacingGain float64 - // The gain currently applied to the congestion window. - congestionWindowGain float64 - - // The gain used for the congestion window during PROBE_BW. Latched from - // quic_bbr_cwnd_gain flag. - congestionWindowGainConstant float64 - // The number of RTTs to stay in STARTUP mode. Defaults to 3. - numStartupRtts int64 - - // Number of round-trips in PROBE_BW mode, used for determining the current - // pacing gain cycle. - cycleCurrentOffset int - // The time at which the last pacing gain cycle was started. - lastCycleStart time.Time - - // Indicates whether the connection has reached the full bandwidth mode. - isAtFullBandwidth bool - // Number of rounds during which there was no significant bandwidth increase. - roundsWithoutBandwidthGain int64 - // The bandwidth compared to which the increase is measured. - bandwidthAtLastRound Bandwidth - - // Set to true upon exiting quiescence. - exitingQuiescence bool - - // Time at which PROBE_RTT has to be exited. Setting it to zero indicates - // that the time is yet unknown as the number of packets in flight has not - // reached the required value. - exitProbeRttAt time.Time - // Indicates whether a round-trip has passed since PROBE_RTT became active. - probeRttRoundPassed bool - - // Indicates whether the most recent bandwidth sample was marked as - // app-limited. - lastSampleIsAppLimited bool - // Indicates whether any non app-limited samples have been recorded. - hasNoAppLimitedSample bool - - // Current state of recovery. - recoveryState bbrRecoveryState - // Receiving acknowledgement of a packet after |end_recovery_at_| will cause - // BBR to exit the recovery mode. A value above zero indicates at least one - // loss has been detected, so it must not be set back to zero. 
- endRecoveryAt congestion.PacketNumber - // A window used to limit the number of bytes in flight during loss recovery. - recoveryWindow congestion.ByteCount - // If true, consider all samples in recovery app-limited. - isAppLimitedRecovery bool // not used - - // When true, pace at 1.5x and disable packet conservation in STARTUP. - slowerStartup bool // not used - // When true, disables packet conservation in STARTUP. - rateBasedStartup bool // not used - - // When true, add the most recent ack aggregation measurement during STARTUP. - enableAckAggregationDuringStartup bool - // When true, expire the windowed ack aggregation values in STARTUP when - // bandwidth increases more than 25%. - expireAckAggregationInStartup bool - - // If true, will not exit low gain mode until bytes_in_flight drops below BDP - // or it's time for high gain mode. - drainToTarget bool - - // If true, slow down pacing rate in STARTUP when overshooting is detected. - detectOvershooting bool - // Bytes lost while detect_overshooting_ is true. - bytesLostWhileDetectingOvershooting congestion.ByteCount - // Slow down pacing rate if - // bytes_lost_while_detecting_overshooting_ * - // bytes_lost_multiplier_while_detecting_overshooting_ > IW. - bytesLostMultiplierWhileDetectingOvershooting uint8 - // When overshooting is detected, do not drop pacing_rate_ below this value / - // min_rtt. - cwndToCalculateMinPacingRate congestion.ByteCount - - // Max congestion window when adjusting network parameters. - maxCongestionWindowWithNetworkParametersAdjusted congestion.ByteCount // not used - - // Params. - maxDatagramSize congestion.ByteCount - // Recorded on packet sent. 
equivalent |unacked_packets_->bytes_in_flight()| - bytesInFlight congestion.ByteCount - - debug bool -} - -var _ congestion.CongestionControl = &bbrSender{} - -func NewBbrSender( - clock Clock, - initialMaxDatagramSize congestion.ByteCount, -) *bbrSender { - return newBbrSender( - clock, - initialMaxDatagramSize, - initialCongestionWindowPackets*initialMaxDatagramSize, - congestion.MaxCongestionWindowPackets*initialMaxDatagramSize, - ) -} - -func newBbrSender( - clock Clock, - initialMaxDatagramSize, - initialCongestionWindow, - initialMaxCongestionWindow congestion.ByteCount, -) *bbrSender { - debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) - b := &bbrSender{ - clock: clock, - mode: bbrModeStartup, - sampler: newBandwidthSampler(roundTripCount(bandwidthWindowSize)), - lastSentPacket: invalidPacketNumber, - currentRoundTripEnd: invalidPacketNumber, - maxBandwidth: NewWindowedFilter(roundTripCount(bandwidthWindowSize), MaxFilter[Bandwidth]), - congestionWindow: initialCongestionWindow, - initialCongestionWindow: initialCongestionWindow, - maxCongestionWindow: initialMaxCongestionWindow, - minCongestionWindow: defaultMinimumCongestionWindow, - highGain: defaultHighGain, - highCwndGain: defaultHighGain, - drainGain: 1.0 / defaultHighGain, - pacingGain: 1.0, - congestionWindowGain: 1.0, - congestionWindowGainConstant: 2.0, - numStartupRtts: roundTripsWithoutGrowthBeforeExitingStartup, - recoveryState: bbrRecoveryStateNotInRecovery, - endRecoveryAt: invalidPacketNumber, - recoveryWindow: initialMaxCongestionWindow, - bytesLostMultiplierWhileDetectingOvershooting: 2, - cwndToCalculateMinPacingRate: initialCongestionWindow, - maxCongestionWindowWithNetworkParametersAdjusted: initialMaxCongestionWindow, - maxDatagramSize: initialMaxDatagramSize, - debug: debug, - } - b.pacer = common.NewPacer(b.bandwidthForPacer) - - /* - if b.tracer != nil { - b.lastState = logging.CongestionStateStartup - b.tracer.UpdatedCongestionState(logging.CongestionStateStartup) - } - */ - - 
b.enterStartupMode(b.clock.Now()) - b.setHighCwndGain(derivedHighCWNDGain) - - return b -} - -func (b *bbrSender) SetRTTStatsProvider(provider congestion.RTTStatsProvider) { - b.rttStats = provider -} - -// TimeUntilSend implements the SendAlgorithm interface. -func (b *bbrSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { - return b.pacer.TimeUntilSend() -} - -// HasPacingBudget implements the SendAlgorithm interface. -func (b *bbrSender) HasPacingBudget(now time.Time) bool { - return b.pacer.Budget(now) >= b.maxDatagramSize -} - -// OnPacketSent implements the SendAlgorithm interface. -func (b *bbrSender) OnPacketSent( - sentTime time.Time, - bytesInFlight congestion.ByteCount, - packetNumber congestion.PacketNumber, - bytes congestion.ByteCount, - isRetransmittable bool, -) { - b.pacer.SentPacket(sentTime, bytes) - - b.lastSentPacket = packetNumber - b.bytesInFlight = bytesInFlight - - if bytesInFlight == 0 { - b.exitingQuiescence = true - } - - b.sampler.OnPacketSent(sentTime, packetNumber, bytes, bytesInFlight, isRetransmittable) -} - -// CanSend implements the SendAlgorithm interface. -func (b *bbrSender) CanSend(bytesInFlight congestion.ByteCount) bool { - return bytesInFlight < b.GetCongestionWindow() -} - -// MaybeExitSlowStart implements the SendAlgorithm interface. -func (b *bbrSender) MaybeExitSlowStart() { - // Do nothing -} - -// OnPacketAcked implements the SendAlgorithm interface. -func (b *bbrSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes, priorInFlight congestion.ByteCount, eventTime time.Time) { - // Do nothing. -} - -// OnPacketLost implements the SendAlgorithm interface. -func (b *bbrSender) OnPacketLost(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { - // Do nothing. -} - -// OnRetransmissionTimeout implements the SendAlgorithm interface. -func (b *bbrSender) OnRetransmissionTimeout(packetsRetransmitted bool) { - // Do nothing. 
-} - -// SetMaxDatagramSize implements the SendAlgorithm interface. -func (b *bbrSender) SetMaxDatagramSize(s congestion.ByteCount) { - if s < b.maxDatagramSize { - panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", b.maxDatagramSize, s)) - } - cwndIsMinCwnd := b.congestionWindow == b.minCongestionWindow - b.maxDatagramSize = s - if cwndIsMinCwnd { - b.congestionWindow = b.minCongestionWindow - } - b.pacer.SetMaxDatagramSize(s) -} - -// InSlowStart implements the SendAlgorithmWithDebugInfos interface. -func (b *bbrSender) InSlowStart() bool { - return b.mode == bbrModeStartup -} - -// InRecovery implements the SendAlgorithmWithDebugInfos interface. -func (b *bbrSender) InRecovery() bool { - return b.recoveryState != bbrRecoveryStateNotInRecovery -} - -// GetCongestionWindow implements the SendAlgorithmWithDebugInfos interface. -func (b *bbrSender) GetCongestionWindow() congestion.ByteCount { - if b.mode == bbrModeProbeRtt { - return b.probeRttCongestionWindow() - } - - if b.InRecovery() { - return min(b.congestionWindow, b.recoveryWindow) - } - - return b.congestionWindow -} - -func (b *bbrSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes, priorInFlight congestion.ByteCount) { - // Do nothing. -} - -func (b *bbrSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { - totalBytesAckedBefore := b.sampler.TotalBytesAcked() - totalBytesLostBefore := b.sampler.TotalBytesLost() - - var isRoundStart, minRttExpired bool - var excessAcked, bytesLost congestion.ByteCount - - // The send state of the largest packet in acked_packets, unless it is - // empty. If acked_packets is empty, it's the send state of the largest - // packet in lost_packets. 
- var lastPacketSendState sendTimeState - - b.maybeAppLimited(priorInFlight) - - // Update bytesInFlight - b.bytesInFlight = priorInFlight - for _, p := range ackedPackets { - b.bytesInFlight -= p.BytesAcked - } - for _, p := range lostPackets { - b.bytesInFlight -= p.BytesLost - } - - if len(ackedPackets) != 0 { - lastAckedPacket := ackedPackets[len(ackedPackets)-1].PacketNumber - isRoundStart = b.updateRoundTripCounter(lastAckedPacket) - b.updateRecoveryState(lastAckedPacket, len(lostPackets) != 0, isRoundStart) - } - - sample := b.sampler.OnCongestionEvent(eventTime, - ackedPackets, lostPackets, b.maxBandwidth.GetBest(), infBandwidth, b.roundTripCount) - if sample.lastPacketSendState.isValid { - b.lastSampleIsAppLimited = sample.lastPacketSendState.isAppLimited - b.hasNoAppLimitedSample = b.hasNoAppLimitedSample || !b.lastSampleIsAppLimited - } - // Avoid updating |max_bandwidth_| if a) this is a loss-only event, or b) all - // packets in |acked_packets| did not generate valid samples. (e.g. ack of - // ack-only packets). In both cases, sampler_.total_bytes_acked() will not - // change. - if totalBytesAckedBefore != b.sampler.TotalBytesAcked() { - if !sample.sampleIsAppLimited || sample.sampleMaxBandwidth > b.maxBandwidth.GetBest() { - b.maxBandwidth.Update(sample.sampleMaxBandwidth, b.roundTripCount) - } - } - - if sample.sampleRtt != infRTT { - minRttExpired = b.maybeUpdateMinRtt(eventTime, sample.sampleRtt) - } - bytesLost = b.sampler.TotalBytesLost() - totalBytesLostBefore - - excessAcked = sample.extraAcked - lastPacketSendState = sample.lastPacketSendState - - if len(lostPackets) != 0 { - b.numLossEventsInRound++ - b.bytesLostInRound += bytesLost - } - - // Handle logic specific to PROBE_BW mode. - if b.mode == bbrModeProbeBw { - b.updateGainCyclePhase(eventTime, priorInFlight, len(lostPackets) != 0) - } - - // Handle logic specific to STARTUP and DRAIN modes. 
- if isRoundStart && !b.isAtFullBandwidth { - b.checkIfFullBandwidthReached(&lastPacketSendState) - } - - b.maybeExitStartupOrDrain(eventTime) - - // Handle logic specific to PROBE_RTT. - b.maybeEnterOrExitProbeRtt(eventTime, isRoundStart, minRttExpired) - - // Calculate number of packets acked and lost. - bytesAcked := b.sampler.TotalBytesAcked() - totalBytesAckedBefore - - // After the model is updated, recalculate the pacing rate and congestion - // window. - b.calculatePacingRate(bytesLost) - b.calculateCongestionWindow(bytesAcked, excessAcked) - b.calculateRecoveryWindow(bytesAcked, bytesLost) - - // Cleanup internal state. - // This is where we clean up obsolete (acked or lost) packets from the bandwidth sampler. - // The "least unacked" should actually be FirstOutstanding, but since we are not passing - // that through OnCongestionEventEx, we will only do an estimate using acked/lost packets - // for now. Because of fast retransmission, they should differ by no more than 2 packets. - // (this is controlled by packetThreshold in quic-go's sentPacketHandler) - var leastUnacked congestion.PacketNumber - if len(ackedPackets) != 0 { - leastUnacked = ackedPackets[len(ackedPackets)-1].PacketNumber - 2 - } else { - leastUnacked = lostPackets[len(lostPackets)-1].PacketNumber + 1 - } - b.sampler.RemoveObsoletePackets(leastUnacked) - - if isRoundStart { - b.numLossEventsInRound = 0 - b.bytesLostInRound = 0 - } -} - -func (b *bbrSender) PacingRate() Bandwidth { - if b.pacingRate == 0 { - return Bandwidth(b.highGain * float64( - BandwidthFromDelta(b.initialCongestionWindow, b.getMinRtt()))) - } - - return b.pacingRate -} - -func (b *bbrSender) hasGoodBandwidthEstimateForResumption() bool { - return b.hasNonAppLimitedSample() -} - -func (b *bbrSender) hasNonAppLimitedSample() bool { - return b.hasNoAppLimitedSample -} - -// Sets the pacing gain used in STARTUP. Must be greater than 1. 
-func (b *bbrSender) setHighGain(highGain float64) { - b.highGain = highGain - if b.mode == bbrModeStartup { - b.pacingGain = highGain - } -} - -// Sets the CWND gain used in STARTUP. Must be greater than 1. -func (b *bbrSender) setHighCwndGain(highCwndGain float64) { - b.highCwndGain = highCwndGain - if b.mode == bbrModeStartup { - b.congestionWindowGain = highCwndGain - } -} - -// Sets the gain used in DRAIN. Must be less than 1. -func (b *bbrSender) setDrainGain(drainGain float64) { - b.drainGain = drainGain -} - -// Get the current bandwidth estimate. Note that Bandwidth is in bits per second. -func (b *bbrSender) bandwidthEstimate() Bandwidth { - return b.maxBandwidth.GetBest() -} - -func (b *bbrSender) bandwidthForPacer() congestion.ByteCount { - bps := congestion.ByteCount(float64(b.bandwidthEstimate()) * b.congestionWindowGain / float64(BytesPerSecond)) - if bps < minBps { - // We need to make sure that the bandwidth value for pacer is never zero, - // otherwise it will go into an edge case where HasPacingBudget = false - // but TimeUntilSend is before, causing the quic-go send loop to go crazy and get stuck. - return minBps - } - return bps -} - -// Returns the current estimate of the RTT of the connection. Outside of the -// edge cases, this is minimum RTT. -func (b *bbrSender) getMinRtt() time.Duration { - if b.minRtt != 0 { - return b.minRtt - } - // min_rtt could be available if the handshake packet gets neutered then - // gets acknowledged. This could only happen for QUIC crypto where we do not - // drop keys. - minRtt := b.rttStats.MinRTT() - if minRtt == 0 { - return 100 * time.Millisecond - } else { - return minRtt - } -} - -// Computes the target congestion window using the specified gain. 
-func (b *bbrSender) getTargetCongestionWindow(gain float64) congestion.ByteCount { - bdp := bdpFromRttAndBandwidth(b.getMinRtt(), b.bandwidthEstimate()) - congestionWindow := congestion.ByteCount(gain * float64(bdp)) - - // BDP estimate will be zero if no bandwidth samples are available yet. - if congestionWindow == 0 { - congestionWindow = congestion.ByteCount(gain * float64(b.initialCongestionWindow)) - } - - return max(congestionWindow, b.minCongestionWindow) -} - -// The target congestion window during PROBE_RTT. -func (b *bbrSender) probeRttCongestionWindow() congestion.ByteCount { - return b.minCongestionWindow -} - -func (b *bbrSender) maybeUpdateMinRtt(now time.Time, sampleMinRtt time.Duration) bool { - // Do not expire min_rtt if none was ever available. - minRttExpired := b.minRtt != 0 && now.After(b.minRttTimestamp.Add(minRttExpiry)) - if minRttExpired || sampleMinRtt < b.minRtt || b.minRtt == 0 { - b.minRtt = sampleMinRtt - b.minRttTimestamp = now - } - - return minRttExpired -} - -// Enters the STARTUP mode. -func (b *bbrSender) enterStartupMode(now time.Time) { - b.mode = bbrModeStartup - // b.maybeTraceStateChange(logging.CongestionStateStartup) - b.pacingGain = b.highGain - b.congestionWindowGain = b.highCwndGain - - if b.debug { - b.debugPrint("Phase: STARTUP") - } -} - -// Enters the PROBE_BW mode. -func (b *bbrSender) enterProbeBandwidthMode(now time.Time) { - b.mode = bbrModeProbeBw - // b.maybeTraceStateChange(logging.CongestionStateProbeBw) - b.congestionWindowGain = b.congestionWindowGainConstant - - // Pick a random offset for the gain cycle out of {0, 2..7} range. 1 is - // excluded because in that case increased gain and decreased gain would not - // follow each other. 
- b.cycleCurrentOffset = int(rand.Int31n(congestion.PacketsPerConnectionID)) % (gainCycleLength - 1) - if b.cycleCurrentOffset >= 1 { - b.cycleCurrentOffset += 1 - } - - b.lastCycleStart = now - b.pacingGain = pacingGain[b.cycleCurrentOffset] - - if b.debug { - b.debugPrint("Phase: PROBE_BW") - } -} - -// Updates the round-trip counter if a round-trip has passed. Returns true if -// the counter has been advanced. -func (b *bbrSender) updateRoundTripCounter(lastAckedPacket congestion.PacketNumber) bool { - if b.currentRoundTripEnd == invalidPacketNumber || lastAckedPacket > b.currentRoundTripEnd { - b.roundTripCount++ - b.currentRoundTripEnd = b.lastSentPacket - return true - } - return false -} - -// Updates the current gain used in PROBE_BW mode. -func (b *bbrSender) updateGainCyclePhase(now time.Time, priorInFlight congestion.ByteCount, hasLosses bool) { - // In most cases, the cycle is advanced after an RTT passes. - shouldAdvanceGainCycling := now.After(b.lastCycleStart.Add(b.getMinRtt())) - // If the pacing gain is above 1.0, the connection is trying to probe the - // bandwidth by increasing the number of bytes in flight to at least - // pacing_gain * BDP. Make sure that it actually reaches the target, as long - // as there are no losses suggesting that the buffers are not able to hold - // that much. - if b.pacingGain > 1.0 && !hasLosses && priorInFlight < b.getTargetCongestionWindow(b.pacingGain) { - shouldAdvanceGainCycling = false - } - - // If pacing gain is below 1.0, the connection is trying to drain the extra - // queue which could have been incurred by probing prior to it. If the number - // of bytes in flight falls down to the estimated BDP value earlier, conclude - // that the queue has been successfully drained and exit this cycle early. 
- if b.pacingGain < 1.0 && b.bytesInFlight <= b.getTargetCongestionWindow(1) { - shouldAdvanceGainCycling = true - } - - if shouldAdvanceGainCycling { - b.cycleCurrentOffset = (b.cycleCurrentOffset + 1) % gainCycleLength - b.lastCycleStart = now - // Stay in low gain mode until the target BDP is hit. - // Low gain mode will be exited immediately when the target BDP is achieved. - if b.drainToTarget && b.pacingGain < 1 && - pacingGain[b.cycleCurrentOffset] == 1 && - b.bytesInFlight > b.getTargetCongestionWindow(1) { - return - } - b.pacingGain = pacingGain[b.cycleCurrentOffset] - } -} - -// Tracks for how many round-trips the bandwidth has not increased -// significantly. -func (b *bbrSender) checkIfFullBandwidthReached(lastPacketSendState *sendTimeState) { - if b.lastSampleIsAppLimited { - return - } - - target := Bandwidth(float64(b.bandwidthAtLastRound) * startupGrowthTarget) - if b.bandwidthEstimate() >= target { - b.bandwidthAtLastRound = b.bandwidthEstimate() - b.roundsWithoutBandwidthGain = 0 - if b.expireAckAggregationInStartup { - // Expire old excess delivery measurements now that bandwidth increased. - b.sampler.ResetMaxAckHeightTracker(0, b.roundTripCount) - } - return - } - - b.roundsWithoutBandwidthGain++ - if b.roundsWithoutBandwidthGain >= b.numStartupRtts || - b.shouldExitStartupDueToLoss(lastPacketSendState) { - b.isAtFullBandwidth = true - } -} - -func (b *bbrSender) maybeAppLimited(bytesInFlight congestion.ByteCount) { - if bytesInFlight < b.getTargetCongestionWindow(1) { - b.sampler.OnAppLimited() - } -} - -// Transitions from STARTUP to DRAIN and from DRAIN to PROBE_BW if -// appropriate. 
-func (b *bbrSender) maybeExitStartupOrDrain(now time.Time) { - if b.mode == bbrModeStartup && b.isAtFullBandwidth { - b.mode = bbrModeDrain - // b.maybeTraceStateChange(logging.CongestionStateDrain) - b.pacingGain = b.drainGain - b.congestionWindowGain = b.highCwndGain - - if b.debug { - b.debugPrint("Phase: DRAIN") - } - } - if b.mode == bbrModeDrain && b.bytesInFlight <= b.getTargetCongestionWindow(1) { - b.enterProbeBandwidthMode(now) - } -} - -// Decides whether to enter or exit PROBE_RTT. -func (b *bbrSender) maybeEnterOrExitProbeRtt(now time.Time, isRoundStart, minRttExpired bool) { - if minRttExpired && !b.exitingQuiescence && b.mode != bbrModeProbeRtt { - b.mode = bbrModeProbeRtt - // b.maybeTraceStateChange(logging.CongestionStateProbRtt) - b.pacingGain = 1.0 - // Do not decide on the time to exit PROBE_RTT until the |bytes_in_flight| - // is at the target small value. - b.exitProbeRttAt = time.Time{} - - if b.debug { - b.debugPrint("BandwidthEstimate: %s, CongestionWindowGain: %.2f, PacingGain: %.2f, PacingRate: %s", - formatSpeed(b.bandwidthEstimate()), b.congestionWindowGain, b.pacingGain, formatSpeed(b.PacingRate())) - b.debugPrint("Phase: PROBE_RTT") - } - } - - if b.mode == bbrModeProbeRtt { - b.sampler.OnAppLimited() - // b.maybeTraceStateChange(logging.CongestionStateApplicationLimited) - - if b.exitProbeRttAt.IsZero() { - // If the window has reached the appropriate size, schedule exiting - // PROBE_RTT. The CWND during PROBE_RTT is kMinimumCongestionWindow, but - // we allow an extra packet since QUIC checks CWND before sending a - // packet. 
- if b.bytesInFlight < b.probeRttCongestionWindow()+congestion.MaxPacketBufferSize { - b.exitProbeRttAt = now.Add(probeRttTime) - b.probeRttRoundPassed = false - } - } else { - if isRoundStart { - b.probeRttRoundPassed = true - } - if now.Sub(b.exitProbeRttAt) >= 0 && b.probeRttRoundPassed { - b.minRttTimestamp = now - if b.debug { - b.debugPrint("MinRTT: %s", b.getMinRtt()) - } - if !b.isAtFullBandwidth { - b.enterStartupMode(now) - } else { - b.enterProbeBandwidthMode(now) - } - } - } - } - - b.exitingQuiescence = false -} - -// Determines whether BBR needs to enter, exit or advance state of the -// recovery. -func (b *bbrSender) updateRecoveryState(lastAckedPacket congestion.PacketNumber, hasLosses, isRoundStart bool) { - // Disable recovery in startup, if loss-based exit is enabled. - if !b.isAtFullBandwidth { - return - } - - // Exit recovery when there are no losses for a round. - if hasLosses { - b.endRecoveryAt = b.lastSentPacket - } - - switch b.recoveryState { - case bbrRecoveryStateNotInRecovery: - if hasLosses { - b.recoveryState = bbrRecoveryStateConservation - // This will cause the |recovery_window_| to be set to the correct - // value in CalculateRecoveryWindow(). - b.recoveryWindow = 0 - // Since the conservation phase is meant to be lasting for a whole - // round, extend the current round as if it were started right now. - b.currentRoundTripEnd = b.lastSentPacket - } - case bbrRecoveryStateConservation: - if isRoundStart { - b.recoveryState = bbrRecoveryStateGrowth - } - fallthrough - case bbrRecoveryStateGrowth: - // Exit recovery if appropriate. - if !hasLosses && lastAckedPacket > b.endRecoveryAt { - b.recoveryState = bbrRecoveryStateNotInRecovery - } - } -} - -// Determines the appropriate pacing rate for the connection. 
-func (b *bbrSender) calculatePacingRate(bytesLost congestion.ByteCount) { - if b.bandwidthEstimate() == 0 { - return - } - - targetRate := Bandwidth(b.pacingGain * float64(b.bandwidthEstimate())) - if b.isAtFullBandwidth { - b.pacingRate = targetRate - return - } - - // Pace at the rate of initial_window / RTT as soon as RTT measurements are - // available. - if b.pacingRate == 0 && b.rttStats.MinRTT() != 0 { - b.pacingRate = BandwidthFromDelta(b.initialCongestionWindow, b.rttStats.MinRTT()) - return - } - - if b.detectOvershooting { - b.bytesLostWhileDetectingOvershooting += bytesLost - // Check for overshooting with network parameters adjusted when pacing rate - // > target_rate and loss has been detected. - if b.pacingRate > targetRate && b.bytesLostWhileDetectingOvershooting > 0 { - if b.hasNoAppLimitedSample || - b.bytesLostWhileDetectingOvershooting*congestion.ByteCount(b.bytesLostMultiplierWhileDetectingOvershooting) > b.initialCongestionWindow { - // We are fairly sure overshoot happens if 1) there is at least one - // non app-limited bw sample or 2) half of IW gets lost. Slow pacing - // rate. - b.pacingRate = max(targetRate, BandwidthFromDelta(b.cwndToCalculateMinPacingRate, b.rttStats.MinRTT())) - b.bytesLostWhileDetectingOvershooting = 0 - b.detectOvershooting = false - } - } - } - - // Do not decrease the pacing rate during startup. - b.pacingRate = max(b.pacingRate, targetRate) -} - -// Determines the appropriate congestion window for the connection. -func (b *bbrSender) calculateCongestionWindow(bytesAcked, excessAcked congestion.ByteCount) { - if b.mode == bbrModeProbeRtt { - return - } - - targetWindow := b.getTargetCongestionWindow(b.congestionWindowGain) - if b.isAtFullBandwidth { - // Add the max recently measured ack aggregation to CWND. - targetWindow += b.sampler.MaxAckHeight() - } else if b.enableAckAggregationDuringStartup { - // Add the most recent excess acked. 
Because CWND never decreases in - // STARTUP, this will automatically create a very localized max filter. - targetWindow += excessAcked - } - - // Instead of immediately setting the target CWND as the new one, BBR grows - // the CWND towards |target_window| by only increasing it |bytes_acked| at a - // time. - if b.isAtFullBandwidth { - b.congestionWindow = min(targetWindow, b.congestionWindow+bytesAcked) - } else if b.congestionWindow < targetWindow || - b.sampler.TotalBytesAcked() < b.initialCongestionWindow { - // If the connection is not yet out of startup phase, do not decrease the - // window. - b.congestionWindow += bytesAcked - } - - // Enforce the limits on the congestion window. - b.congestionWindow = max(b.congestionWindow, b.minCongestionWindow) - b.congestionWindow = min(b.congestionWindow, b.maxCongestionWindow) -} - -// Determines the appropriate window that constrains the in-flight during recovery. -func (b *bbrSender) calculateRecoveryWindow(bytesAcked, bytesLost congestion.ByteCount) { - if b.recoveryState == bbrRecoveryStateNotInRecovery { - return - } - - // Set up the initial recovery window. - if b.recoveryWindow == 0 { - b.recoveryWindow = b.bytesInFlight + bytesAcked - b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow) - return - } - - // Remove losses from the recovery window, while accounting for a potential - // integer underflow. - if b.recoveryWindow >= bytesLost { - b.recoveryWindow = b.recoveryWindow - bytesLost - } else { - b.recoveryWindow = b.maxDatagramSize - } - - // In CONSERVATION mode, just subtracting losses is sufficient. In GROWTH, - // release additional |bytes_acked| to achieve a slow-start-like behavior. - if b.recoveryState == bbrRecoveryStateGrowth { - b.recoveryWindow += bytesAcked - } - - // Always allow sending at least |bytes_acked| in response. 
- b.recoveryWindow = max(b.recoveryWindow, b.bytesInFlight+bytesAcked) - b.recoveryWindow = max(b.minCongestionWindow, b.recoveryWindow) -} - -// Return whether we should exit STARTUP due to excessive loss. -func (b *bbrSender) shouldExitStartupDueToLoss(lastPacketSendState *sendTimeState) bool { - if b.numLossEventsInRound < defaultStartupFullLossCount || !lastPacketSendState.isValid { - return false - } - - inflightAtSend := lastPacketSendState.bytesInFlight - - if inflightAtSend > 0 && b.bytesLostInRound > 0 { - if b.bytesLostInRound > congestion.ByteCount(float64(inflightAtSend)*quicBbr2DefaultLossThreshold) { - return true - } - return false - } - return false -} - -func (b *bbrSender) debugPrint(format string, a ...any) { - fmt.Printf("[BBRSender] [%s] %s\n", - time.Now().Format("15:04:05"), - fmt.Sprintf(format, a...)) -} - -func bdpFromRttAndBandwidth(rtt time.Duration, bandwidth Bandwidth) congestion.ByteCount { - return congestion.ByteCount(rtt) * congestion.ByteCount(bandwidth) / congestion.ByteCount(BytesPerSecond) / congestion.ByteCount(time.Second) -} - -func GetInitialPacketSize(addr net.Addr) congestion.ByteCount { - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. 
- if udpAddr, ok := addr.(*net.UDPAddr); ok { - if udpAddr.IP.To4() != nil { - return congestion.InitialPacketSizeIPv4 - } else { - return congestion.InitialPacketSizeIPv6 - } - } else { - return congestion.MinInitialPacketSize - } -} - -func formatSpeed(bw Bandwidth) string { - bwf := float64(bw) - units := []string{"bps", "Kbps", "Mbps", "Gbps"} - unitIndex := 0 - for bwf > 1024 && unitIndex < len(units)-1 { - bwf /= 1024 - unitIndex++ - } - return fmt.Sprintf("%.2f %s", bwf, units[unitIndex]) -} diff --git a/protocol/hysteria2/internal/congestion/bbr/clock.go b/protocol/hysteria2/internal/congestion/bbr/clock.go deleted file mode 100644 index a66344f..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/clock.go +++ /dev/null @@ -1,18 +0,0 @@ -package bbr - -import "time" - -// A Clock returns the current time -type Clock interface { - Now() time.Time -} - -// DefaultClock implements the Clock interface using the Go stdlib clock. -type DefaultClock struct{} - -var _ Clock = DefaultClock{} - -// Now gets the current time -func (DefaultClock) Now() time.Time { - return time.Now() -} diff --git a/protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go b/protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go deleted file mode 100644 index e9fad5a..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/packet_number_indexed_queue.go +++ /dev/null @@ -1,199 +0,0 @@ -package bbr - -import ( - "github.com/daeuniverse/quic-go/congestion" -) - -// packetNumberIndexedQueue is a queue of mostly continuous numbered entries -// which supports the following operations: -// - adding elements to the end of the queue, or at some point past the end -// - removing elements in any order -// - retrieving elements -// If all elements are inserted in order, all of the operations above are -// amortized O(1) time. -// -// Internally, the data structure is a deque where each element is marked as -// present or not. 
The deque starts at the lowest present index. Whenever an -// element is removed, it's marked as not present, and the front of the deque is -// cleared of elements that are not present. -// -// The tail of the queue is not cleared due to the assumption of entries being -// inserted in order, though removing all elements of the queue will return it -// to its initial state. -// -// Note that this data structure is inherently hazardous, since an addition of -// just two entries will cause it to consume all of the memory available. -// Because of that, it is not a general-purpose container and should not be used -// as one. - -type entryWrapper[T any] struct { - present bool - entry T -} - -type packetNumberIndexedQueue[T any] struct { - entries RingBuffer[entryWrapper[T]] - numberOfPresentEntries int - firstPacket congestion.PacketNumber -} - -func newPacketNumberIndexedQueue[T any](size int) *packetNumberIndexedQueue[T] { - q := &packetNumberIndexedQueue[T]{ - firstPacket: invalidPacketNumber, - } - - q.entries.Init(size) - - return q -} - -// Emplace inserts data associated |packet_number| into (or past) the end of the -// queue, filling up the missing intermediate entries as necessary. Returns -// true if the element has been inserted successfully, false if it was already -// in the queue or inserted out of order. -func (p *packetNumberIndexedQueue[T]) Emplace(packetNumber congestion.PacketNumber, entry *T) bool { - if packetNumber == invalidPacketNumber || entry == nil { - return false - } - - if p.IsEmpty() { - p.entries.PushBack(entryWrapper[T]{ - present: true, - entry: *entry, - }) - p.numberOfPresentEntries = 1 - p.firstPacket = packetNumber - return true - } - - // Do not allow insertion out-of-order. - if packetNumber <= p.LastPacket() { - return false - } - - // Handle potentially missing elements. 
- offset := int(packetNumber - p.FirstPacket()) - if gap := offset - p.entries.Len(); gap > 0 { - for i := 0; i < gap; i++ { - p.entries.PushBack(entryWrapper[T]{}) - } - } - - p.entries.PushBack(entryWrapper[T]{ - present: true, - entry: *entry, - }) - p.numberOfPresentEntries++ - return true -} - -// GetEntry Retrieve the entry associated with the packet number. Returns the pointer -// to the entry in case of success, or nullptr if the entry does not exist. -func (p *packetNumberIndexedQueue[T]) GetEntry(packetNumber congestion.PacketNumber) *T { - ew := p.getEntryWraper(packetNumber) - if ew == nil { - return nil - } - - return &ew.entry -} - -// Remove, Same as above, but if an entry is present in the queue, also call f(entry) -// before removing it. -func (p *packetNumberIndexedQueue[T]) Remove(packetNumber congestion.PacketNumber, f func(T)) bool { - ew := p.getEntryWraper(packetNumber) - if ew == nil { - return false - } - if f != nil { - f(ew.entry) - } - ew.present = false - p.numberOfPresentEntries-- - - if packetNumber == p.FirstPacket() { - p.clearup() - } - - return true -} - -// RemoveUpTo, but not including |packet_number|. -// Unused slots in the front are also removed, which means when the function -// returns, |first_packet()| can be larger than |packet_number|. -func (p *packetNumberIndexedQueue[T]) RemoveUpTo(packetNumber congestion.PacketNumber) { - for !p.entries.Empty() && - p.firstPacket != invalidPacketNumber && - p.firstPacket < packetNumber { - if p.entries.Front().present { - p.numberOfPresentEntries-- - } - p.entries.PopFront() - p.firstPacket++ - } - p.clearup() - - return -} - -// IsEmpty return if queue is empty. -func (p *packetNumberIndexedQueue[T]) IsEmpty() bool { - return p.numberOfPresentEntries == 0 -} - -// NumberOfPresentEntries returns the number of entries in the queue. 
-func (p *packetNumberIndexedQueue[T]) NumberOfPresentEntries() int { - return p.numberOfPresentEntries -} - -// EntrySlotsUsed returns the number of entries allocated in the underlying deque. This is -// proportional to the memory usage of the queue. -func (p *packetNumberIndexedQueue[T]) EntrySlotsUsed() int { - return p.entries.Len() -} - -// LastPacket returns packet number of the first entry in the queue. -func (p *packetNumberIndexedQueue[T]) FirstPacket() (packetNumber congestion.PacketNumber) { - return p.firstPacket -} - -// LastPacket returns packet number of the last entry ever inserted in the queue. Note that the -// entry in question may have already been removed. Zero if the queue is -// empty. -func (p *packetNumberIndexedQueue[T]) LastPacket() (packetNumber congestion.PacketNumber) { - if p.IsEmpty() { - return invalidPacketNumber - } - - return p.firstPacket + congestion.PacketNumber(p.entries.Len()-1) -} - -func (p *packetNumberIndexedQueue[T]) clearup() { - for !p.entries.Empty() && !p.entries.Front().present { - p.entries.PopFront() - p.firstPacket++ - } - if p.entries.Empty() { - p.firstPacket = invalidPacketNumber - } -} - -func (p *packetNumberIndexedQueue[T]) getEntryWraper(packetNumber congestion.PacketNumber) *entryWrapper[T] { - if packetNumber == invalidPacketNumber || - p.IsEmpty() || - packetNumber < p.firstPacket { - return nil - } - - offset := int(packetNumber - p.firstPacket) - if offset >= p.entries.Len() { - return nil - } - - ew := p.entries.Offset(offset) - if ew == nil || !ew.present { - return nil - } - - return ew -} diff --git a/protocol/hysteria2/internal/congestion/bbr/ringbuffer.go b/protocol/hysteria2/internal/congestion/bbr/ringbuffer.go deleted file mode 100644 index ed92d4c..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/ringbuffer.go +++ /dev/null @@ -1,118 +0,0 @@ -package bbr - -// A RingBuffer is a ring buffer. -// It acts as a heap that doesn't cause any allocations. 
-type RingBuffer[T any] struct { - ring []T - headPos, tailPos int - full bool -} - -// Init preallocs a buffer with a certain size. -func (r *RingBuffer[T]) Init(size int) { - r.ring = make([]T, size) -} - -// Len returns the number of elements in the ring buffer. -func (r *RingBuffer[T]) Len() int { - if r.full { - return len(r.ring) - } - if r.tailPos >= r.headPos { - return r.tailPos - r.headPos - } - return r.tailPos - r.headPos + len(r.ring) -} - -// Empty says if the ring buffer is empty. -func (r *RingBuffer[T]) Empty() bool { - return !r.full && r.headPos == r.tailPos -} - -// PushBack adds a new element. -// If the ring buffer is full, its capacity is increased first. -func (r *RingBuffer[T]) PushBack(t T) { - if r.full || len(r.ring) == 0 { - r.grow() - } - r.ring[r.tailPos] = t - r.tailPos++ - if r.tailPos == len(r.ring) { - r.tailPos = 0 - } - if r.tailPos == r.headPos { - r.full = true - } -} - -// PopFront returns the next element. -// It must not be called when the buffer is empty, that means that -// callers might need to check if there are elements in the buffer first. -func (r *RingBuffer[T]) PopFront() T { - if r.Empty() { - panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: pop from an empty queue") - } - r.full = false - t := r.ring[r.headPos] - r.ring[r.headPos] = *new(T) - r.headPos++ - if r.headPos == len(r.ring) { - r.headPos = 0 - } - return t -} - -// Offset returns the offset element. -// It must not be called when the buffer is empty, that means that -// callers might need to check if there are elements in the buffer first -// and check if the index larger than buffer length. -func (r *RingBuffer[T]) Offset(index int) *T { - if r.Empty() || index >= r.Len() { - panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: offset from invalid index") - } - offset := (r.headPos + index) % len(r.ring) - return &r.ring[offset] -} - -// Front returns the front element. 
-// It must not be called when the buffer is empty, that means that -// callers might need to check if there are elements in the buffer first. -func (r *RingBuffer[T]) Front() *T { - if r.Empty() { - panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: front from an empty queue") - } - return &r.ring[r.headPos] -} - -// Back returns the back element. -// It must not be called when the buffer is empty, that means that -// callers might need to check if there are elements in the buffer first. -func (r *RingBuffer[T]) Back() *T { - if r.Empty() { - panic("github.com/quic-go/quic-go/internal/utils/ringbuffer: back from an empty queue") - } - return r.Offset(r.Len() - 1) -} - -// Grow the maximum size of the queue. -// This method assume the queue is full. -func (r *RingBuffer[T]) grow() { - oldRing := r.ring - newSize := len(oldRing) * 2 - if newSize == 0 { - newSize = 1 - } - r.ring = make([]T, newSize) - headLen := copy(r.ring, oldRing[r.headPos:]) - copy(r.ring[headLen:], oldRing[:r.headPos]) - r.headPos, r.tailPos, r.full = 0, len(oldRing), false -} - -// Clear removes all elements. -func (r *RingBuffer[T]) Clear() { - var zeroValue T - for i := range r.ring { - r.ring[i] = zeroValue - } - r.headPos, r.tailPos, r.full = 0, 0, false -} diff --git a/protocol/hysteria2/internal/congestion/bbr/windowed_filter.go b/protocol/hysteria2/internal/congestion/bbr/windowed_filter.go deleted file mode 100644 index 4773bce..0000000 --- a/protocol/hysteria2/internal/congestion/bbr/windowed_filter.go +++ /dev/null @@ -1,162 +0,0 @@ -package bbr - -import ( - "golang.org/x/exp/constraints" -) - -// Implements Kathleen Nichols' algorithm for tracking the minimum (or maximum) -// estimate of a stream of samples over some fixed time interval. (E.g., -// the minimum RTT over the past five minutes.) The algorithm keeps track of -// the best, second best, and third best min (or max) estimates, maintaining an -// invariant that the measurement time of the n'th best >= n-1'th best. 
- -// The algorithm works as follows. On a reset, all three estimates are set to -// the same sample. The second best estimate is then recorded in the second -// quarter of the window, and a third best estimate is recorded in the second -// half of the window, bounding the worst case error when the true min is -// monotonically increasing (or true max is monotonically decreasing) over the -// window. -// -// A new best sample replaces all three estimates, since the new best is lower -// (or higher) than everything else in the window and it is the most recent. -// The window thus effectively gets reset on every new min. The same property -// holds true for second best and third best estimates. Specifically, when a -// sample arrives that is better than the second best but not better than the -// best, it replaces the second and third best estimates but not the best -// estimate. Similarly, a sample that is better than the third best estimate -// but not the other estimates replaces only the third best estimate. -// -// Finally, when the best expires, it is replaced by the second best, which in -// turn is replaced by the third best. The newest sample replaces the third -// best. - -type WindowedFilterValue interface { - any -} - -type WindowedFilterTime interface { - constraints.Integer | constraints.Float -} - -type WindowedFilter[V WindowedFilterValue, T WindowedFilterTime] struct { - // Time length of window. - windowLength T - estimates []entry[V, T] - comparator func(V, V) int -} - -type entry[V WindowedFilterValue, T WindowedFilterTime] struct { - sample V - time T -} - -// Compares two values and returns true if the first is greater than or equal -// to the second. -func MaxFilter[O constraints.Ordered](a, b O) int { - if a > b { - return 1 - } else if a < b { - return -1 - } - return 0 -} - -// Compares two values and returns true if the first is less than or equal -// to the second. 
-func MinFilter[O constraints.Ordered](a, b O) int { - if a < b { - return 1 - } else if a > b { - return -1 - } - return 0 -} - -func NewWindowedFilter[V WindowedFilterValue, T WindowedFilterTime](windowLength T, comparator func(V, V) int) *WindowedFilter[V, T] { - return &WindowedFilter[V, T]{ - windowLength: windowLength, - estimates: make([]entry[V, T], 3, 3), - comparator: comparator, - } -} - -// Changes the window length. Does not update any current samples. -func (f *WindowedFilter[V, T]) SetWindowLength(windowLength T) { - f.windowLength = windowLength -} - -func (f *WindowedFilter[V, T]) GetBest() V { - return f.estimates[0].sample -} - -func (f *WindowedFilter[V, T]) GetSecondBest() V { - return f.estimates[1].sample -} - -func (f *WindowedFilter[V, T]) GetThirdBest() V { - return f.estimates[2].sample -} - -// Updates best estimates with |sample|, and expires and updates best -// estimates as necessary. -func (f *WindowedFilter[V, T]) Update(newSample V, newTime T) { - // Reset all estimates if they have not yet been initialized, if new sample - // is a new best, or if the newest recorded estimate is too old. - if f.comparator(f.estimates[0].sample, *new(V)) == 0 || - f.comparator(newSample, f.estimates[0].sample) >= 0 || - newTime-f.estimates[2].time > f.windowLength { - f.Reset(newSample, newTime) - return - } - - if f.comparator(newSample, f.estimates[1].sample) >= 0 { - f.estimates[1] = entry[V, T]{newSample, newTime} - f.estimates[2] = f.estimates[1] - } else if f.comparator(newSample, f.estimates[2].sample) >= 0 { - f.estimates[2] = entry[V, T]{newSample, newTime} - } - - // Expire and update estimates as necessary. - if newTime-f.estimates[0].time > f.windowLength { - // The best estimate hasn't been updated for an entire window, so promote - // second and third best estimates. - f.estimates[0] = f.estimates[1] - f.estimates[1] = f.estimates[2] - f.estimates[2] = entry[V, T]{newSample, newTime} - // Need to iterate one more time. 
Check if the new best estimate is - // outside the window as well, since it may also have been recorded a - // long time ago. Don't need to iterate once more since we cover that - // case at the beginning of the method. - if newTime-f.estimates[0].time > f.windowLength { - f.estimates[0] = f.estimates[1] - f.estimates[1] = f.estimates[2] - } - return - } - if f.comparator(f.estimates[1].sample, f.estimates[0].sample) == 0 && - newTime-f.estimates[1].time > f.windowLength/4 { - // A quarter of the window has passed without a better sample, so the - // second-best estimate is taken from the second quarter of the window. - f.estimates[1] = entry[V, T]{newSample, newTime} - f.estimates[2] = f.estimates[1] - return - } - - if f.comparator(f.estimates[2].sample, f.estimates[1].sample) == 0 && - newTime-f.estimates[2].time > f.windowLength/2 { - // We've passed a half of the window without a better estimate, so take - // a third-best estimate from the second half of the window. - f.estimates[2] = entry[V, T]{newSample, newTime} - } -} - -// Resets all estimates to new sample. 
-func (f *WindowedFilter[V, T]) Reset(newSample V, newTime T) { - f.estimates[2] = entry[V, T]{newSample, newTime} - f.estimates[1] = f.estimates[2] - f.estimates[0] = f.estimates[1] -} - -func (f *WindowedFilter[V, T]) Clear() { - f.estimates = make([]entry[V, T], 3, 3) -} diff --git a/protocol/hysteria2/internal/congestion/brutal/brutal.go b/protocol/hysteria2/internal/congestion/brutal/brutal.go deleted file mode 100644 index b353090..0000000 --- a/protocol/hysteria2/internal/congestion/brutal/brutal.go +++ /dev/null @@ -1,185 +0,0 @@ -package brutal - -import ( - "fmt" - "os" - "strconv" - "time" - - "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/common" - - "github.com/daeuniverse/quic-go/congestion" -) - -const ( - pktInfoSlotCount = 5 // slot index is based on seconds, so this is basically how many seconds we sample - minSampleCount = 50 - minAckRate = 0.8 - congestionWindowMultiplier = 2 - - debugEnv = "HYSTERIA_BRUTAL_DEBUG" - debugPrintInterval = 2 -) - -var _ congestion.CongestionControl = &BrutalSender{} - -type BrutalSender struct { - rttStats congestion.RTTStatsProvider - bps congestion.ByteCount - maxDatagramSize congestion.ByteCount - pacer *common.Pacer - - pktInfoSlots [pktInfoSlotCount]pktInfo - ackRate float64 - - debug bool - lastAckPrintTimestamp int64 -} - -type pktInfo struct { - Timestamp int64 - AckCount uint64 - LossCount uint64 -} - -func NewBrutalSender(bps uint64) *BrutalSender { - debug, _ := strconv.ParseBool(os.Getenv(debugEnv)) - bs := &BrutalSender{ - bps: congestion.ByteCount(bps), - maxDatagramSize: congestion.InitialPacketSizeIPv4, - ackRate: 1, - debug: debug, - } - bs.pacer = common.NewPacer(func() congestion.ByteCount { - return congestion.ByteCount(float64(bs.bps) / bs.ackRate) - }) - return bs -} - -func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { - b.rttStats = rttStats -} - -func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { - 
return b.pacer.TimeUntilSend() -} - -func (b *BrutalSender) HasPacingBudget(now time.Time) bool { - return b.pacer.Budget(now) >= b.maxDatagramSize -} - -func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { - return bytesInFlight <= b.GetCongestionWindow() -} - -func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { - rtt := b.rttStats.SmoothedRTT() - if rtt <= 0 { - return 10240 - } - cwnd := congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) - if cwnd < b.maxDatagramSize { - cwnd = b.maxDatagramSize - } - return cwnd -} - -func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, - packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, -) { - b.pacer.SentPacket(sentTime, bytes) -} - -func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, eventTime time.Time, -) { - // Stub -} - -func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount, - priorInFlight congestion.ByteCount, -) { - // Stub -} - -func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) { - currentTimestamp := eventTime.Unix() - slot := currentTimestamp % pktInfoSlotCount - if b.pktInfoSlots[slot].Timestamp == currentTimestamp { - b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets)) - b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets)) - } else { - // uninitialized slot or too old, reset - b.pktInfoSlots[slot].Timestamp = currentTimestamp - b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets)) - b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets)) - } - b.updateAckRate(currentTimestamp) -} - -func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { - b.maxDatagramSize 
= size - b.pacer.SetMaxDatagramSize(size) - if b.debug { - b.debugPrint("SetMaxDatagramSize: %d", size) - } -} - -func (b *BrutalSender) updateAckRate(currentTimestamp int64) { - minTimestamp := currentTimestamp - pktInfoSlotCount - var ackCount, lossCount uint64 - for _, info := range b.pktInfoSlots { - if info.Timestamp < minTimestamp { - continue - } - ackCount += info.AckCount - lossCount += info.LossCount - } - if ackCount+lossCount < minSampleCount { - b.ackRate = 1 - if b.canPrintAckRate(currentTimestamp) { - b.lastAckPrintTimestamp = currentTimestamp - b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)", - ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) - } - return - } - rate := float64(ackCount) / float64(ackCount+lossCount) - if rate < minAckRate { - b.ackRate = minAckRate - if b.canPrintAckRate(currentTimestamp) { - b.lastAckPrintTimestamp = currentTimestamp - b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", - rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) - } - return - } - b.ackRate = rate - if b.canPrintAckRate(currentTimestamp) { - b.lastAckPrintTimestamp = currentTimestamp - b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)", - rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds()) - } -} - -func (b *BrutalSender) InSlowStart() bool { - return false -} - -func (b *BrutalSender) InRecovery() bool { - return false -} - -func (b *BrutalSender) MaybeExitSlowStart() {} - -func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} - -func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool { - return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval -} - -func (b *BrutalSender) debugPrint(format string, a ...any) { - fmt.Printf("[BrutalSender] [%s] %s\n", - time.Now().Format("15:04:05"), - fmt.Sprintf(format, 
a...)) -} diff --git a/protocol/hysteria2/internal/congestion/common/pacer.go b/protocol/hysteria2/internal/congestion/common/pacer.go deleted file mode 100644 index 9d55876..0000000 --- a/protocol/hysteria2/internal/congestion/common/pacer.go +++ /dev/null @@ -1,79 +0,0 @@ -package common - -import ( - "time" - - "github.com/daeuniverse/quic-go/congestion" -) - -const ( - maxBurstPackets = 10 - maxBurstPacingDelayMultiplier = 4 -) - -// Pacer implements a token bucket pacing algorithm. -type Pacer struct { - budgetAtLastSent congestion.ByteCount - maxDatagramSize congestion.ByteCount - lastSentTime time.Time - getBandwidth func() congestion.ByteCount // in bytes/s -} - -func NewPacer(getBandwidth func() congestion.ByteCount) *Pacer { - p := &Pacer{ - budgetAtLastSent: maxBurstPackets * congestion.InitialPacketSizeIPv4, - maxDatagramSize: congestion.InitialPacketSizeIPv4, - getBandwidth: getBandwidth, - } - return p -} - -func (p *Pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { - budget := p.Budget(sendTime) - if size > budget { - p.budgetAtLastSent = 0 - } else { - p.budgetAtLastSent = budget - size - } - p.lastSentTime = sendTime -} - -func (p *Pacer) Budget(now time.Time) congestion.ByteCount { - if p.lastSentTime.IsZero() { - return p.maxBurstSize() - } - budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 - if budget < 0 { // protect against overflows - budget = congestion.ByteCount(1<<62 - 1) - } - return min(p.maxBurstSize(), budget) -} - -func (p *Pacer) maxBurstSize() congestion.ByteCount { - return max( - congestion.ByteCount((maxBurstPacingDelayMultiplier*congestion.MinPacingDelay).Nanoseconds())*p.getBandwidth()/1e9, - maxBurstPackets*p.maxDatagramSize, - ) -} - -// TimeUntilSend returns when the next packet should be sent. -// It returns the zero value of time.Time if a packet can be sent immediately. 
-func (p *Pacer) TimeUntilSend() time.Time { - if p.budgetAtLastSent >= p.maxDatagramSize { - return time.Time{} - } - diff := 1e9 * uint64(p.maxDatagramSize-p.budgetAtLastSent) - bw := uint64(p.getBandwidth()) - // We might need to round up this value. - // Otherwise, we might have a budget (slightly) smaller than the datagram size when the timer expires. - d := diff / bw - // this is effectively a math.Ceil, but using only integer math - if diff%bw > 0 { - d++ - } - return p.lastSentTime.Add(max(congestion.MinPacingDelay, time.Duration(d)*time.Nanosecond)) -} - -func (p *Pacer) SetMaxDatagramSize(s congestion.ByteCount) { - p.maxDatagramSize = s -} diff --git a/protocol/hysteria2/internal/congestion/utils.go b/protocol/hysteria2/internal/congestion/utils.go deleted file mode 100644 index 99a562a..0000000 --- a/protocol/hysteria2/internal/congestion/utils.go +++ /dev/null @@ -1,18 +0,0 @@ -package congestion - -import ( - "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/bbr" - "github.com/daeuniverse/outbound/protocol/hysteria2/internal/congestion/brutal" - "github.com/daeuniverse/quic-go" -) - -func UseBBR(conn quic.Connection) { - conn.SetCongestionControl(bbr.NewBbrSender( - bbr.DefaultClock{}, - bbr.GetInitialPacketSize(conn.RemoteAddr()), - )) -} - -func UseBrutal(conn quic.Connection, tx uint64) { - conn.SetCongestionControl(brutal.NewBrutalSender(tx)) -} diff --git a/protocol/tuic/congestion/brutal/brutal.go b/protocol/tuic/congestion/brutal/brutal.go index ae203d0..15276b4 100644 --- a/protocol/tuic/congestion/brutal/brutal.go +++ b/protocol/tuic/congestion/brutal/brutal.go @@ -69,7 +69,7 @@ func (b *BrutalSender) HasPacingBudget(now time.Time) bool { } func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { - return bytesInFlight < b.GetCongestionWindow() + return bytesInFlight <= b.GetCongestionWindow() } func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { @@ -77,7 +77,11 @@ func (b *BrutalSender) 
GetCongestionWindow() congestion.ByteCount { if rtt <= 0 { return 10240 } - return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) + cwnd := congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate) + if cwnd < b.maxDatagramSize { + cwnd = b.maxDatagramSize + } + return cwnd } func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, From c8bd546e2a2156950bacdacbb614f49cd78f63a2 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Fri, 7 Jun 2024 14:28:52 +0800 Subject: [PATCH 13/13] feat: Add Hysteria2 license claim document --- protocol/hysteria2/README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 protocol/hysteria2/README.md diff --git a/protocol/hysteria2/README.md b/protocol/hysteria2/README.md new file mode 100644 index 0000000..16c97e7 --- /dev/null +++ b/protocol/hysteria2/README.md @@ -0,0 +1,17 @@ +# Hysteria2 + +This part of the code is modified from [`apernet/hysteria`](https://github.com/apernet/hysteria/) with many thanks. + +## License + +See the [LICENSE](/LICENSE) file for license rights and limitations. + +Portions of this code are derived from the Hysteria project under the [MIT License](https://github.com/apernet/hysteria/blob/52c8f82c2ba3172660152b3d6918797dd49eff13/LICENSE.md): + + Copyright 2023 Toby + + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
+ + THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.