diff --git a/http2/transport.go b/http2/transport.go index 91f4370ccf..30f706e6cb 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -118,6 +118,15 @@ type Transport struct { // to mean no limit. MaxHeaderListSize uint32 + // MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the + // initial settings frame. It is the size in bytes of the largest frame + // payload that the sender is willing to receive. If 0, no setting is + // sent, and the value is provided by the peer, which should be 16384 + // according to the spec: + // https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2. + // Values are bounded in the range 16k to 16M. + MaxReadFrameSize uint32 + // MaxDecoderHeaderTableSize optionally specifies the http2 // SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It // informs the remote endpoint of the maximum size of the header compression @@ -184,6 +193,19 @@ func (t *Transport) maxHeaderListSize() uint32 { return t.MaxHeaderListSize } +func (t *Transport) maxFrameReadSize() uint32 { + if t.MaxReadFrameSize == 0 { + return 0 // use the default provided by the peer + } + if t.MaxReadFrameSize < minMaxFrameSize { + return minMaxFrameSize + } + if t.MaxReadFrameSize > maxFrameSize { + return maxFrameSize + } + return t.MaxReadFrameSize +} + func (t *Transport) disableCompression() bool { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } @@ -749,6 +771,9 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro }) cc.br = bufio.NewReader(c) cc.fr = NewFramer(cc.bw, cc.br) + if t.maxFrameReadSize() != 0 { + cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize()) + } if t.CountError != nil { cc.fr.countError = t.CountError } @@ -773,6 +798,9 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro {ID: SettingEnablePush, Val: 0}, {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow}, } + if max := t.maxFrameReadSize(); max != 0 { + initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max}) + } if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max}) } diff --git a/http2/transport_test.go b/http2/transport_test.go index ee852b6198..42d7dbc5e7 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3998,6 +3998,64 @@ func TestTransportResponseDataBeforeHeaders(t *testing.T) { ct.run() } +// Test that the server received values are in range. +func TestTransportMaxFrameReadSize(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + }, func(s *Server) { + s.MaxConcurrentStreams = 1 + }) + defer st.Close() + tr := &Transport{ + TLSClientConfig: tlsConfigInsecure, + MaxReadFrameSize: 64000, + } + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + if got, want := res.StatusCode, 200; got != want { + t.Errorf("StatusCode = %v; want %v", got, want) + } + if res != nil && res.Body != nil { + res.Body.Close() + } + // Test that maxFrameReadSize() matches the requested size. + if got, want := tr.maxFrameReadSize(), uint32(64000); got != want { + t.Errorf("maxFrameReadSize = %v; want %v", got, want) + } + // Test that server receives the requested size. + if got, want := st.sc.maxFrameSize, int32(64000); got != want { + t.Errorf("maxFrameReadSize = %v; want %v", got, want) + } + + // test for minimum frame read size + tr = &Transport{ + TLSClientConfig: tlsConfigInsecure, + MaxReadFrameSize: 1024, + } + + if got, want := tr.maxFrameReadSize(), uint32(minMaxFrameSize); got != want { + t.Errorf("maxFrameReadSize = %v; want %v", got, want) + } + + // test for maximum frame size + tr = &Transport{ + TLSClientConfig: tlsConfigInsecure, + MaxReadFrameSize: 1024 * 1024 * 16, + } + + if got, want := tr.maxFrameReadSize(), uint32(maxFrameSize); got != want { + t.Errorf("maxFrameReadSize = %v; want %v", got, want) + } + +} + func TestTransportRequestsLowServerLimit(t *testing.T) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { }, optOnlyServer, func(s *Server) { @@ -4608,6 +4666,61 @@ func BenchmarkClientResponseHeaders(b *testing.B) { b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) }) } +func BenchmarkDownloadFrameSize(b *testing.B) { + b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) }) + b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) }) + b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) }) + b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) }) + b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) }) +} +func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) { + defer disableGoroutineTracking()() + const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M + b.ReportAllocs() + st := newServerTester(b, + func(w http.ResponseWriter, r *http.Request) { + // test 1GB transfer + w.Header().Set("Content-Length", strconv.Itoa(transferSize)) + w.Header().Set("Content-Transfer-Encoding", "binary") + var data [1024 * 1024]byte + for i := 0; i < transferSize/(1024*1024); i++ { + w.Write(data[:]) + } + }, optQuiet, + ) + defer st.Close() + + tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize} + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + b.Fatal(err) + } + + b.N = 3 + b.SetBytes(transferSize) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := tr.RoundTrip(req) + if err != nil { + if res != nil { + res.Body.Close() + } + b.Fatalf("RoundTrip err = %v; want nil", err) + } + data, _ := io.ReadAll(res.Body) + if len(data) != transferSize { + b.Fatalf("Response length invalid") + } + res.Body.Close() + if res.StatusCode != http.StatusOK { + b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK) + } + } +} + func activeStreams(cc *ClientConn) int { count := 0 cc.mu.Lock()