Skip to content

Commit

Permalink
compress/flate: return error on closed stream write
Browse files Browse the repository at this point in the history
Previously flate.Writer allowed writes after Close, and this behavior
could lead to stream corruption.

Fixes #27741

Change-Id: Iee1ac69f8199232f693dba77b275f7078257b582
Reviewed-on: https://go-review.googlesource.com/c/go/+/136475
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Carlos Amedee <carlos@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
  • Loading branch information
gregory-m authored and gopherbot committed May 2, 2022
1 parent fdf1d76 commit 2719b1a
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 4 deletions.
36 changes: 33 additions & 3 deletions src/compress/flate/deflate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package flate

import (
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -699,17 +700,27 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
return w.w.Write(b)
}

var errWriteAfterClose = errors.New("compress/flate: write after close")

// A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter).
type Writer struct {
d compressor
dict []byte
err error
}

// Write writes data to w, which will eventually write the
// compressed form of data to its underlying writer.
func (w *Writer) Write(data []byte) (n int, err error) {
return w.d.write(data)
if w.err != nil {
return 0, w.err
}
n, err = w.d.write(data)
if err != nil {
w.err = err
}
return n, err
}

// Flush flushes any pending data to the underlying writer.
Expand All @@ -724,18 +735,37 @@ func (w *Writer) Write(data []byte) (n int, err error) {
func (w *Writer) Flush() error {
// For more about flushing:
// https://www.bolet.org/~pornin/deflate-flush.html
return w.d.syncFlush()
if w.err != nil {
return w.err
}
if err := w.d.syncFlush(); err != nil {
w.err = err
return err
}
return nil
}

// Close flushes and closes the writer.
func (w *Writer) Close() error {
return w.d.close()
if w.err == errWriteAfterClose {
return nil
}
if w.err != nil {
return w.err
}
if err := w.d.close(); err != nil {
w.err = err
return err
}
w.err = errWriteAfterClose
return nil
}

// Reset discards the writer's state and makes it equivalent to
// the result of NewWriter or NewWriterDict called with dst
// and w's level and dictionary.
func (w *Writer) Reset(dst io.Writer) {
w.err = nil
if dw, ok := w.d.w.writer.(*dictWriter); ok {
// w was created with NewWriterDict
dw.w = dst
Expand Down
88 changes: 87 additions & 1 deletion src/compress/flate/deflate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,40 @@ func TestDeflate(t *testing.T) {
}
}

func TestWriterClose(t *testing.T) {
b := new(bytes.Buffer)
zw, err := NewWriter(b, 6)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}

if c, err := zw.Write([]byte("Test")); err != nil || c != 4 {
t.Fatalf("Write to not closed writer: %s, %d", err, c)
}

if err := zw.Close(); err != nil {
t.Fatalf("Close: %v", err)
}

afterClose := b.Len()

if c, err := zw.Write([]byte("Test")); err == nil || c != 0 {
t.Fatalf("Write to closed writer: %s, %d", err, c)
}

if err := zw.Flush(); err == nil {
t.Fatalf("Flush to closed writer: %s", err)
}

if err := zw.Close(); err != nil {
t.Fatalf("Close: %v", err)
}

if afterClose != b.Len() {
t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len())
}
}

// A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
// This tests missing hash references in a very large input.
type sparseReader struct {
Expand Down Expand Up @@ -683,7 +717,7 @@ func (w *failWriter) Write(b []byte) (int, error) {
return len(b), nil
}

func TestWriterPersistentError(t *testing.T) {
func TestWriterPersistentWriteError(t *testing.T) {
t.Parallel()
d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
if err != nil {
Expand All @@ -706,19 +740,71 @@ func TestWriterPersistentError(t *testing.T) {

_, werr := zw.Write(d)
cerr := zw.Close()
ferr := zw.Flush()
if werr != errIO && werr != nil {
t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
}
if cerr != errIO && fw.n < 0 {
t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
}
if ferr != errIO && fw.n < 0 {
t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO)
}
if fw.n >= 0 {
// At this point, the failure threshold was sufficiently high enough
// that we wrote the whole stream without any errors.
return
}
}
}
func TestWriterPersistentFlushError(t *testing.T) {
zw, err := NewWriter(&failWriter{0}, DefaultCompression)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
flushErr := zw.Flush()
closeErr := zw.Close()
_, writeErr := zw.Write([]byte("Test"))
checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
}

func TestWriterPersistentCloseError(t *testing.T) {
// If underlying writer return error on closing stream we should persistent this error across all writer calls.
zw, err := NewWriter(&failWriter{0}, DefaultCompression)
if err != nil {
t.Fatalf("NewWriter: %v", err)
}
closeErr := zw.Close()
flushErr := zw.Flush()
_, writeErr := zw.Write([]byte("Test"))
checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)

// After closing writer we should persistent "write after close" error across Flush and Write calls, but return nil
// on next Close calls.
var b bytes.Buffer
zw.Reset(&b)
err = zw.Close()
if err != nil {
t.Fatalf("First call to close returned error: %s", err)
}
err = zw.Close()
if err != nil {
t.Fatalf("Second call to close returned error: %s", err)
}

flushErr = zw.Flush()
_, writeErr = zw.Write([]byte("Test"))
checkErrors([]error{flushErr, writeErr}, errWriteAfterClose, t)
}

func checkErrors(got []error, want error, t *testing.T) {
t.Helper()
for _, err := range got {
if err != want {
t.Errorf("Errors dosn't match\nWant: %s\nGot: %s", want, got)
}
}
}

func TestBestSpeedMatch(t *testing.T) {
t.Parallel()
Expand Down

0 comments on commit 2719b1a

Please sign in to comment.