diff --git a/gzip/gunzip.go b/gzip/gunzip.go index 66fe5ddf72..6d630c390d 100644 --- a/gzip/gunzip.go +++ b/gzip/gunzip.go @@ -288,10 +288,35 @@ func (z *Reader) Read(p []byte) (n int, err error) { return n, nil } -// Support the io.WriteTo interface for io.Copy and friends. +type crcer interface { + io.Writer + Sum32() uint32 + Reset() +} +type crcUpdater struct { + z *Reader +} + +func (c *crcUpdater) Write(p []byte) (int, error) { + c.z.digest = crc32.Update(c.z.digest, crc32.IEEETable, p) + return len(p), nil +} + +func (c *crcUpdater) Sum32() uint32 { + return c.z.digest +} + +func (c *crcUpdater) Reset() { + c.z.digest = 0 +} + +// WriteTo support the io.WriteTo interface for io.Copy and friends. func (z *Reader) WriteTo(w io.Writer) (int64, error) { total := int64(0) - crcWriter := crc32.NewIEEE() + crcWriter := crcer(crc32.NewIEEE()) + if z.digest != 0 { + crcWriter = &crcUpdater{z: z} + } for { if z.err != nil { if z.err == io.EOF { diff --git a/gzip/gunzip_test.go b/gzip/gunzip_test.go index d11325c285..399149a4b4 100644 --- a/gzip/gunzip_test.go +++ b/gzip/gunzip_test.go @@ -5,6 +5,7 @@ package gzip import ( + "bufio" "bytes" oldgz "compress/gzip" "crypto/rand" @@ -537,6 +538,7 @@ func TestWriteTo(t *testing.T) { t.Fatal(err) } wtbuf := &bytes.Buffer{} + written, err := dec.WriteTo(wtbuf) if err != nil { t.Fatal(err) @@ -708,3 +710,43 @@ func TestTruncatedGunzip(t *testing.T) { } } } + +func TestBufferedPartialCopyGzip(t *testing.T) { + var ( + in = []byte("hello\nworld") + compressedIn []byte + ) + + var buf bytes.Buffer + gzw := NewWriter(&buf) + if _, err := gzw.Write(in); err != nil { + panic(err) + } + if err := gzw.Flush(); err != nil { + panic(err) + } + if err := gzw.Close(); err != nil { + panic(err) + } + + compressedIn = buf.Bytes() + + gz, err := NewReader(bytes.NewReader(compressedIn)) + if err != nil { + t.Errorf("constructing a reader: %v", err) + } + + br := bufio.NewReader(gz) + if _, err := br.ReadBytes('\n'); err != nil { + t.Errorf("reading to the first newline: %v", err) + } + + var out bytes.Buffer + _, err = io.Copy(&out, br) + if !bytes.Equal(out.Bytes(), []byte("world")) { + t.Errorf("unexpected output when reading the remainder") + } + if err != nil { + t.Errorf("reading the remainder: %v", err) + } +}