Skip to content

Commit

Permalink
SplitHTTP: Fix connection leaks and crashes (XTLS#3710)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmmray authored and zxspirit committed Aug 30, 2024
1 parent c83ce43 commit cfb9474
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 65 deletions.
41 changes: 32 additions & 9 deletions transport/internet/splithttp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
var downResponse io.ReadCloser
gotDownResponse := done.New()

ctx, ctxCancel := context.WithCancel(ctx)

go func() {
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
Expand All @@ -61,8 +63,10 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
// in case we hit an error, we want to unblock this part
defer gotConn.Close()

ctx = httptrace.WithClientTrace(ctx, trace)

req, err := http.NewRequestWithContext(
httptrace.WithClientTrace(ctx, trace),
ctx,
"GET",
baseURL,
nil,
Expand Down Expand Up @@ -94,16 +98,17 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
gotDownResponse.Close()
}()

if c.isH3 {
gotConn.Close()
if !c.isH3 {
// in quic-go, sometimes gotConn is never closed for the lifetime of
// the entire connection, and the download locks up
// https://github.com/quic-go/quic-go/issues/3342
// for other HTTP versions, we want to block Dial until we know the
// remote address of the server, for logging purposes
<-gotConn.Wait()
}

// we want to block Dial until we know the remote address of the server,
// for logging purposes
<-gotConn.Wait()

lazyDownload := &LazyReader{
CreateReader: func() (io.ReadCloser, error) {
CreateReader: func() (io.Reader, error) {
<-gotDownResponse.Wait()
if downResponse == nil {
return nil, errors.New("downResponse failed")
Expand All @@ -112,7 +117,15 @@ func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string)
},
}

return lazyDownload, remoteAddr, localAddr, nil
// workaround for https://github.com/quic-go/quic-go/issues/2143 --
// always cancel request context so that Close cancels any Read.
// Should then match the behavior of http2 and http1.
reader := downloadBody{
lazyDownload,
ctxCancel,
}

return reader, remoteAddr, localAddr, nil
}

func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
Expand Down Expand Up @@ -172,3 +185,13 @@ func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string,

return nil
}

type downloadBody struct {
io.Reader
cancel context.CancelFunc
}

func (c downloadBody) Close() error {
c.cancel()
return nil
}
34 changes: 2 additions & 32 deletions transport/internet/splithttp/dialer.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package splithttp

import (
"bytes"
"context"
gotls "crypto/tls"
"io"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -292,35 +290,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
return nil, err
}

lazyDownload := &LazyReader{
CreateReader: func() (io.ReadCloser, error) {
// skip "ok" response
trashHeader := []byte{0, 0}
_, err := io.ReadFull(lazyRawDownload, trashHeader)
if err != nil {
return nil, errors.New("failed to read initial response").Base(err)
}

if bytes.Equal(trashHeader, []byte("ok")) {
return lazyRawDownload, nil
}

// we read some garbage byte that may not have been "ok" at
// all. return a reader that replays what we have read so far
reader := io.MultiReader(
bytes.NewReader(trashHeader),
lazyRawDownload,
)
readCloser := struct {
io.Reader
io.Closer
}{
Reader: reader,
Closer: lazyRawDownload,
}
return readCloser, nil
},
}
reader := &stripOkReader{ReadCloser: lazyRawDownload}

writer := uploadWriter{
uploadPipeWriter,
Expand All @@ -329,7 +299,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me

conn := splitConn{
writer: writer,
reader: lazyDownload,
reader: reader,
remoteAddr: remoteAddr,
localAddr: localAddr,
}
Expand Down
6 changes: 5 additions & 1 deletion transport/internet/splithttp/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,12 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
h.ln.addConn(stat.Connection(&conn))

// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
<-downloadDone.Wait()
select {
case <-request.Context().Done():
case <-downloadDone.Wait():
}

conn.Close()
} else {
writer.WriteHeader(http.StatusMethodNotAllowed)
}
Expand Down
26 changes: 7 additions & 19 deletions transport/internet/splithttp/lazy_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ package splithttp
import (
"io"
"sync"

"github.com/xtls/xray-core/common/errors"
)

// Close is intentionally not supported by LazyReader because it's not clear
// how CreateReader should be aborted in case of Close. It's best to wrap
// LazyReader in another struct that handles Close correctly, or better, stop
// using LazyReader entirely.
type LazyReader struct {
readerSync sync.Mutex
CreateReader func() (io.ReadCloser, error)
reader io.ReadCloser
CreateReader func() (io.Reader, error)
reader io.Reader
readerError error
}

func (r *LazyReader) getReader() (io.ReadCloser, error) {
func (r *LazyReader) getReader() (io.Reader, error) {
r.readerSync.Lock()
defer r.readerSync.Unlock()
if r.reader != nil {
Expand Down Expand Up @@ -43,17 +45,3 @@ func (r *LazyReader) Read(b []byte) (int, error) {
n, err := reader.Read(b)
return n, err
}

func (r *LazyReader) Close() error {
r.readerSync.Lock()
defer r.readerSync.Unlock()

var err error
if r.reader != nil {
err = r.reader.Close()
r.reader = nil
r.readerError = errors.New("closed reader")
}

return err
}
13 changes: 11 additions & 2 deletions transport/internet/splithttp/splithttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
NextProtocol: []string{"h3"},
},
}

serverClosed := false
listen, err := ListenSH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
go func() {
defer conn.Close()
Expand All @@ -258,10 +260,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
for {
b.Clear()
if _, err := b.ReadFrom(conn); err != nil {
return
break
}
common.Must2(conn.Write(b.Bytes()))
}

serverClosed = true
}()
})
common.Must(err)
Expand All @@ -271,7 +275,6 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {

conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
common.Must(err)
defer conn.Close()

const N = 1024
b1 := make([]byte, N)
Expand All @@ -294,6 +297,12 @@ func Test_listenSHAndDial_QUIC(t *testing.T) {
t.Error(r)
}

conn.Close()
time.Sleep(100 * time.Millisecond)
if !serverClosed {
t.Error("server did not get closed")
}

end := time.Now()
if !end.Before(start.Add(time.Second * 5)) {
t.Error("end: ", end, " start: ", start)
Expand Down
48 changes: 48 additions & 0 deletions transport/internet/splithttp/strip_ok_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package splithttp

import (
"bytes"
"io"

"github.com/xtls/xray-core/common/errors"
)

// in older versions of splithttp, the server would respond with `ok` to flush
// out HTTP response headers early. Response headers and a 200 OK were required
// to initiate the connection. Later versions of splithttp dropped this
// requirement, and in xray 1.8.24 the server stopped sending "ok" if it sees
// x_padding. For compatibility, we need to remove "ok" from the underlying
// reader if it exists, and otherwise forward the stream as-is.
type stripOkReader struct {
io.ReadCloser
firstDone bool
prefixRead []byte
}

func (r *stripOkReader) Read(b []byte) (int, error) {
if !r.firstDone {
r.firstDone = true

// skip "ok" response
prefixRead := []byte{0, 0}
_, err := io.ReadFull(r.ReadCloser, prefixRead)
if err != nil {
return 0, errors.New("failed to read initial response").Base(err)
}

if !bytes.Equal(prefixRead, []byte("ok")) {
// we read some garbage byte that may not have been "ok" at
// all. return a reader that replays what we have read so far
r.prefixRead = prefixRead
}
}

if len(r.prefixRead) > 0 {
n := copy(b, r.prefixRead)
r.prefixRead = r.prefixRead[n:]
return n, nil
}

n, err := r.ReadCloser.Read(b)
return n, err
}
6 changes: 4 additions & 2 deletions transport/internet/splithttp/upload_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ func (h *uploadQueue) Close() error {
h.writeCloseMutex.Lock()
defer h.writeCloseMutex.Unlock()

h.closed = true
close(h.pushedPackets)
if !h.closed {
h.closed = true
close(h.pushedPackets)
}
return nil
}

Expand Down

0 comments on commit cfb9474

Please sign in to comment.