Skip to content

Commit

Permalink
blob: pass through reader/writer to WriteTo/ReadFrom if available (
Browse files Browse the repository at this point in the history
  • Loading branch information
HippoBaro authored Jun 27, 2023
1 parent adb7ff5 commit 8385fc3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
12 changes: 12 additions & 0 deletions blob/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,12 @@ func (r *Reader) As(i interface{}) bool {
//
// It implements the io.WriterTo interface.
func (r *Reader) WriteTo(w io.Writer) (int64, error) {
// If the writer has a ReaderFrom method, use it to do the copy.
// Avoids an allocation and a copy.
if rt, ok := w.(io.ReaderFrom); ok {
return rt.ReadFrom(r)
}

_, nw, err := readFromWriteTo(r, w)
return nw, err
}
Expand Down Expand Up @@ -476,6 +482,12 @@ func (w *Writer) write(p []byte) (int, error) {
//
// It implements the io.ReaderFrom interface.
func (w *Writer) ReadFrom(r io.Reader) (int64, error) {
// If the reader has a WriteTo method, use it to do the copy.
// Avoids an allocation and a copy.
if wt, ok := r.(io.WriterTo); ok {
return wt.WriteTo(w)
}

nr, _, err := readFromWriteTo(r, w)
return nr, err
}
Expand Down
25 changes: 22 additions & 3 deletions blob/blob_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package blob_test

import (
"bytes"
"context"
"io"
"testing"
"testing/iotest"

Expand All @@ -41,12 +43,29 @@ func TestReader(t *testing.T) {
bucket.WriteAll(ctx, myKey, data, nil)

// Create a blob.Reader.
r, err := bucket.NewReader(ctx, myKey, nil)
r1, err := bucket.NewReader(ctx, myKey, nil)
if err != nil {
t.Fatal(err)
}
defer r.Close()
if err := iotest.TestReader(r, data); err != nil {
r1.Close()
if err := iotest.TestReader(r1, data); err != nil {
t.Error(err)
}

// Create another blob.Reader to exercise the ReadFrom code path
r2, err := bucket.NewReader(ctx, myKey, nil)
if err != nil {
t.Fatal(err)
}
defer r2.Close()

var buffer bytes.Buffer
n, err := io.Copy(&buffer, r2)
if err != nil {
t.Fatal(err)
} else if n != int64(len(data)) {
t.Fatal("wrote fewer bytes than expected")
} else if !bytes.Equal(buffer.Bytes(), data) {
t.Fatal("wrote invalid bytes")
}
}
13 changes: 11 additions & 2 deletions blob/drivertest/drivertest.go
Original file line number Diff line number Diff line change
Expand Up @@ -2624,12 +2624,21 @@ func benchmarkRead(b *testing.B, bkt *blob.Bucket) {

b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
var buffer bytes.Buffer
buffer.Grow(len(content))

for pb.Next() {
buf, err := bkt.ReadAll(ctx, key)
buffer.Reset()
r, err := bkt.NewReader(ctx, key, nil)
if err != nil {
b.Error(err)
}
if !bytes.Equal(buf, content) {

if _, err = io.Copy(&buffer, r); err != nil {
b.Error(err)
}
r.Close()
if !bytes.Equal(buffer.Bytes(), content) {
b.Error("read didn't match write")
}
}
Expand Down

0 comments on commit 8385fc3

Please sign in to comment.