Skip to content

Commit

Permalink
Ensuring that compression readers guarantee the underlying buffers are
Browse files Browse the repository at this point in the history
closed
  • Loading branch information
MovieStoreGuy committed Sep 11, 2024
1 parent 970dcfa commit 5cf5312
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 31 deletions.
34 changes: 5 additions & 29 deletions config/confighttp/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"bytes"
"compress/gzip"
"compress/zlib"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -33,11 +32,7 @@ var availableDecoders = map[string]func(body io.ReadCloser) (io.ReadCloser, erro
return nil, nil
},
"gzip": func(body io.ReadCloser) (io.ReadCloser, error) {
gr, err := gzip.NewReader(body)
if err != nil {
return nil, err
}
return gr, nil
return gzip.NewReader(body)
},
"zstd": func(body io.ReadCloser) (io.ReadCloser, error) {
zr, err := zstd.NewReader(
Expand All @@ -54,32 +49,14 @@ var availableDecoders = map[string]func(body io.ReadCloser) (io.ReadCloser, erro
return zr.IOReadCloser(), nil
},
"zlib": func(body io.ReadCloser) (io.ReadCloser, error) {
zr, err := zlib.NewReader(body)
if err != nil {
return nil, err
}
return zr, nil
return zlib.NewReader(body)
},
"snappy": func(body io.ReadCloser) (io.ReadCloser, error) {
sr := snappy.NewReader(body)
sb := new(bytes.Buffer)
_, err := io.Copy(sb, sr)
if err != nil {
return nil, err
}
if err = body.Close(); err != nil {
return nil, err
}
return io.NopCloser(sb), nil
return newCompressionReader(snappy.NewReader, body), nil
},
"lz4": func(body io.ReadCloser) (io.ReadCloser, error) {
lz := lz4.NewReader(body)
buf := new(bytes.Buffer)
_, err := io.Copy(buf, lz)
if err = errors.Join(err, body.Close()); err != nil {
return nil, err
}
return io.NopCloser(buf), nil
return newCompressionReader(lz4.NewReader, body), nil

},
}

Expand Down Expand Up @@ -134,7 +111,6 @@ type decompressor struct {
// httpContentDecompressor offloads the task of handling compressed HTTP requests
// by identifying the compression format in the "Content-Encoding" header and re-writing
// request body so that the handlers further in the chain can work on decompressed data.
// It supports gzip and deflate/zlib compression.
func httpContentDecompressor(h http.Handler, maxRequestBodySize int64, eh func(w http.ResponseWriter, r *http.Request, errorMsg string, statusCode int), enableDecoders []string, decoders map[string]func(body io.ReadCloser) (io.ReadCloser, error)) http.Handler {
errHandler := defaultErrorHandler
if eh != nil {
Expand Down
36 changes: 36 additions & 0 deletions config/confighttp/compression_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package confighttp // import "go.opentelemetry.io/collector/config/confighttp"

import (
"errors"
"io"
)

// compressionReader couples a decompressing reader with the original
// (compressed) stream it wraps, so that closing it also drains and
// closes the underlying source. Reads are delegated to the embedded
// decompressing Reader.
type compressionReader[R io.Reader] struct {
	io.Reader
	// orig is the compressed source; it is drained and closed on Close
	// so the underlying buffers can be released.
	orig io.ReadCloser
}

// Compile-time checks that compressionReader satisfies both interfaces.
var (
	_ io.Reader = (*compressionReader[io.Reader])(nil)
	_ io.Closer = (*compressionReader[io.Reader])(nil)
)

// newCompressionReader is used to couple the original underlying buffer and
// the compression reader to allow for close operations to correctly
// free up the underlying buffer that was provided by the original reader.
//
// method constructs the decompressing reader (e.g. snappy.NewReader,
// lz4.NewReader) around orig; the generic parameter R lets callers pass
// constructors that return a concrete reader type directly.
func newCompressionReader[R io.Reader](method func(io.Reader) R, orig io.ReadCloser) io.ReadCloser {
	return &compressionReader[R]{
		Reader: method(orig),
		orig:   orig,
	}
}

// Close drains any remaining compressed input so the underlying buffers
// are released, closes the decompressing reader when it holds resources
// of its own, and finally closes the original source. All errors are
// joined so none is silently dropped.
func (cr *compressionReader[R]) Close() error {
	// taking the original compressed buffer and discarding it
	// to ensure the underlying buffers are released
	_, drainErr := io.Copy(io.Discard, cr.orig)
	// Some decompressors (e.g. gzip/zlib readers) implement io.Closer;
	// release their resources as well when present.
	var innerErr error
	if c, ok := cr.Reader.(io.Closer); ok {
		innerErr = c.Close()
	}
	return errors.Join(drainErr, innerErr, cr.orig.Close())
}
69 changes: 69 additions & 0 deletions config/confighttp/compression_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package confighttp

import (
"bytes"
"io"
"testing"
"testing/iotest"

"github.com/golang/snappy"
"github.com/pierrec/lz4/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCompressionReaders(t *testing.T) {
	t.Parallel()

	// Each case compresses a payload and builds the matching reader;
	// the reader must return the original data and fully consume the
	// compressed buffer on Close.
	testCases := []struct {
		name     string
		compress func(testing.TB, []byte) *bytes.Buffer
		createFn func(orig io.Reader) (io.ReadCloser, error)
	}{
		{
			name: "no compression",
			compress: func(_ testing.TB, b []byte) *bytes.Buffer {
				return bytes.NewBuffer(b)
			},
			createFn: func(orig io.Reader) (io.ReadCloser, error) {
				identity := func(r io.Reader) io.Reader { return r }
				return newCompressionReader(identity, io.NopCloser(orig)), nil
			},
		},
		{
			name:     "snappy",
			compress: compressSnappy,
			createFn: func(orig io.Reader) (io.ReadCloser, error) {
				return newCompressionReader(snappy.NewReader, io.NopCloser(orig)), nil
			},
		},
		{
			name:     "lz4",
			compress: compressLz4,
			createFn: func(orig io.Reader) (io.ReadCloser, error) {
				return newCompressionReader(lz4.NewReader, io.NopCloser(orig)), nil
			},
		},
	}

	for _, tt := range testCases {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			const data = "hello world"
			payload := tt.compress(t, []byte(data))

			rc, err := tt.createFn(payload)
			require.NoError(t, err, "Must not error creating compression reader")

			assert.NoError(t, iotest.TestReader(rc, []byte(data)), "Must pass the expected reader test")
			assert.NoError(t, rc.Close(), "Must not error on close")
			assert.Zero(t, payload.Len(), "Must have consumed entire payload buffer")
		})
	}
}
74 changes: 72 additions & 2 deletions config/confighttp/compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,14 @@ func TestHTTPContentDecompressionHandler(t *testing.T) {
encoding: "snappy",
reqBody: bytes.NewBuffer(testBody),
respCode: http.StatusBadRequest,
respBody: "snappy: corrupt input\n",
respBody: "snappy: corrupt input",
},
{
name: "InvalidLz4",
encoding: "lz4",
reqBody: bytes.NewBuffer(testBody),
respCode: http.StatusBadRequest,
respBody: "lz4: bad magic number\n",
respBody: "lz4: bad magic number",
},
{
name: "UnsupportedCompression",
Expand Down Expand Up @@ -395,6 +395,76 @@ func TestOverrideCompressionList(t *testing.T) {
require.NoError(t, res.Body.Close(), "failed to close request body: %v", err)
}

func TestDecompressorAvoidDecompressionBomb(t *testing.T) {
	t.Parallel()

	// None encoding is ignored since it does not
	// enforce the max body size if content encoding header is not set
	cases := []struct {
		name     string
		encoding string
		compress func(tb testing.TB, payload []byte) *bytes.Buffer
	}{
		{name: "gzip", encoding: "gzip", compress: compressGzip},
		{name: "zstd", encoding: "zstd", compress: compressZstd},
		{name: "zlib", encoding: "zlib", compress: compressZlib},
		{name: "snappy", encoding: "snappy", compress: compressSnappy},
		{name: "lz4", encoding: "lz4", compress: compressLz4},
	}

	const limit = 1024

	for _, tt := range cases {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// The handler must see exactly `limit` bytes before the
			// request-body size cap trips, proving decompression stops
			// at the configured maximum.
			handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				n, err := io.Copy(io.Discard, r.Body)
				assert.Equal(t, int64(1024), n, "Must have only read the limited value of bytes")
				assert.EqualError(t, err, "http: request body too large")
				w.WriteHeader(http.StatusBadRequest)
			})

			h := httpContentDecompressor(
				handler,
				limit,
				defaultErrorHandler,
				defaultCompressionAlgorithms,
				availableDecoders,
			)

			payload := tt.compress(t, make([]byte, 2*limit)) // 2KB uncompressed payload
			assert.NotEmpty(t, payload.Bytes(), "Must have data available")

			req := httptest.NewRequest(http.MethodPost, "/", payload)
			req.Header.Set("Content-Encoding", tt.encoding)

			resp := httptest.NewRecorder()
			h.ServeHTTP(resp, req)

			assert.Equal(t, http.StatusBadRequest, resp.Code, "Must match the expected code")
			assert.Empty(t, resp.Body.String(), "Must match the returned string")
			assert.Empty(t, payload.Bytes(), "Must have consumed original payload")
		})
	}
}

func compressGzip(t testing.TB, body []byte) *bytes.Buffer {
var buf bytes.Buffer
gw := gzip.NewWriter(&buf)
Expand Down

0 comments on commit 5cf5312

Please sign in to comment.