Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: merge go1.17.4 and add factory for uTLS proxy #16

Merged
merged 5 commits into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,6 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
if err == nil {
return n, nil
}
b.stop()
if err == io.EOF {
return n, err
}
Expand Down
27 changes: 27 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,33 @@ func TestClientTimeoutCancel(t *testing.T) {
}
}

func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) }
func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) }

// Issue 49366: if Client.Timeout is set but not hit, no error should be returned.
func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) {
setParallel(t)
defer afterTest(t)

cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("body"))
}))
defer cst.close()

cst.c.Timeout = 1 * time.Hour
req, _ := NewRequest("GET", cst.ts.URL, nil)
res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
if _, err = io.Copy(io.Discard, res.Body); err != nil {
t.Fatalf("io.Copy(io.Discard, res.Body) = %v, want nil", err)
}
if err = res.Body.Close(); err != nil {
t.Fatalf("res.Body.Close() = %v, want nil", err)
}
}

func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
func testClientRedirectEatsBody(t *testing.T, h2 bool) {
Expand Down
34 changes: 34 additions & 0 deletions clientserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1583,3 +1583,37 @@ func TestH12_WebSocketUpgrade(t *testing.T) {
},
}.run(t)
}

func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) }
func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) }

func testIdentityTransferEncoding(t *testing.T, h2 bool) {
setParallel(t)
defer afterTest(t)

const body = "body"
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
gotBody, _ := io.ReadAll(r.Body)
if got, want := string(gotBody), body; got != want {
t.Errorf("got request body = %q; want %q", got, want)
}
w.Header().Set("Transfer-Encoding", "identity")
w.WriteHeader(StatusOK)
w.(Flusher).Flush()
io.WriteString(w, body)
}))
defer cst.close()
req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body))
res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
gotBody, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}
if got, want := string(gotBody), body; got != want {
t.Errorf("got response body = %q; want %q", got, want)
}
}
81 changes: 49 additions & 32 deletions h2_bundle.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1404,11 +1404,11 @@ func (cw *chunkWriter) writeHeader(p []byte) {
hasCL = false
}

if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) {
// do nothing
} else if code == StatusNoContent {
if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent {
// Response has no body.
delHeader("Transfer-Encoding")
} else if hasCL {
// Content-Length has been provided, so no chunking is to be done.
delHeader("Transfer-Encoding")
} else if w.req.ProtoAtLeast(1, 1) {
// HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no
Expand All @@ -1419,6 +1419,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
if hasTE && te == "identity" {
cw.chunking = false
w.closeAfterReply = true
delHeader("Transfer-Encoding")
} else {
// HTTP/1.1 or greater: use chunked transfer encoding
// to avoid closing the connection at EOF.
Expand Down
6 changes: 6 additions & 0 deletions tlsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ type TLSConn interface {
// in time by the given context.
HandshakeContext(ctx context.Context) error
}

// TLSClientFactory is the factory used when creating connections
// using a proxy inside of the HTTP library. By default, this is
// the tls.Client function. You'll need to override this factory if
// you want to use refraction-networking/utls for proxied conns.
var TLSClientFactory = tls.Client
2 changes: 1 addition & 1 deletion transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1519,7 +1519,7 @@ func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptr
cfg.NextProtos = nil
}
plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg)
tlsConn := TLSClientFactory(plainConn, cfg)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake
if d := pconn.t.TLSHandshakeTimeout; d != 0 {
Expand Down