From c1b787b66651943be1cc411948b43d09e4e306c2 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Sun, 5 Feb 2023 13:56:10 +0100 Subject: [PATCH] s2: Support ReadAt in ReadSeeker Also simplifies seeking. --- s2/decode.go | 121 +++++++++++++++++++++++++++++++++++----------- s2/encode_test.go | 31 ++++++++++++ s2/index_test.go | 60 +++++++++++++++++++++++ 3 files changed, 183 insertions(+), 29 deletions(-) diff --git a/s2/decode.go b/s2/decode.go index 00c5cc72c2..a89db8f96a 100644 --- a/s2/decode.go +++ b/s2/decode.go @@ -880,15 +880,20 @@ func (r *Reader) Skip(n int64) error { // See Reader.ReadSeeker type ReadSeeker struct { *Reader + readAtMu sync.Mutex } -// ReadSeeker will return an io.ReadSeeker compatible version of the reader. +// ReadSeeker will return an io.ReadSeeker and io.ReaderAt +// compatible version of the reader. // If 'random' is specified the returned io.Seeker can be used for // random seeking, otherwise only forward seeking is supported. // Enabling random seeking requires the original input to support // the io.Seeker interface. // A custom index can be specified which will be used if supplied. // When using a custom index, it will not be read from the input stream. +// The ReadAt position will affect regular reads and the current position of Seek. +// So using Read after ReadAt will continue from where the ReadAt stopped. +// No functions should be used concurrently. // The returned ReadSeeker contains a shallow reference to the existing Reader, // meaning changes performed to one is reflected in the other. func (r *Reader) ReadSeeker(random bool, index []byte) (*ReadSeeker, error) { @@ -958,42 +963,55 @@ func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) { // Reset on EOF r.err = nil } - if offset == 0 && whence == io.SeekCurrent { - return r.blockStart + int64(r.i), nil + + // Calculate absolute offset. + absOffset := offset + + switch whence { + case io.SeekStart: + case io.SeekCurrent: + absOffset = r.blockStart + int64(r.i) + offset + case io.SeekEnd: + if r.index == nil { + return 0, ErrUnsupported + } + absOffset = r.index.TotalUncompressed + offset + default: + r.err = ErrUnsupported + return 0, r.err + } + + if absOffset < 0 { + return 0, errors.New("seek before start of file") } + if !r.readHeader { // Make sure we read the header. _, r.err = r.Read([]byte{}) + if r.err != nil { + return 0, r.err + } } + + // If we are inside current block no need to seek. + // This includes no offset changes. + if absOffset >= r.blockStart && absOffset < r.blockStart+int64(r.j) { + r.i = int(absOffset - r.blockStart) + return r.blockStart + int64(r.i), nil + } + rs, ok := r.r.(io.ReadSeeker) if r.index == nil || !ok { - if whence == io.SeekCurrent && offset >= 0 { - err := r.Skip(offset) - return r.blockStart + int64(r.i), err - } - if whence == io.SeekStart && offset >= r.blockStart+int64(r.i) { - err := r.Skip(offset - r.blockStart - int64(r.i)) + currOffset := r.blockStart + int64(r.i) + if absOffset >= currOffset { + err := r.Skip(absOffset - currOffset) return r.blockStart + int64(r.i), err } return 0, ErrUnsupported - } - switch whence { - case io.SeekCurrent: - offset += r.blockStart + int64(r.i) - case io.SeekEnd: - if offset > 0 { - return 0, errors.New("seek after end of file") - } - offset = r.index.TotalUncompressed + offset - } - - if offset < 0 { - return 0, errors.New("seek before start of file") - } - - c, u, err := r.index.Find(offset) + // We can seek and we have an index. + c, u, err := r.index.Find(absOffset) if err != nil { return r.blockStart + int64(r.i), err } @@ -1004,12 +1022,57 @@ func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) { return 0, err } - r.i = r.j // Remove rest of current block. - if u < offset { + r.i = r.j // Remove rest of current block. + r.blockStart = u - int64(r.j) // Adjust current block start for accounting. + if u < absOffset { // Forward inside block - return offset, r.Skip(offset - u) + return absOffset, r.Skip(absOffset - u) + } + if u > absOffset { + return 0, fmt.Errorf("s2 seek: (internal error) u (%d) > absOffset (%d)", u, absOffset) + } + return absOffset, nil +} + +// ReadAt reads len(p) bytes into p starting at offset off in the +// underlying input source. It returns the number of bytes +// read (0 <= n <= len(p)) and any error encountered. +// +// When ReadAt returns n < len(p), it returns a non-nil error +// explaining why more bytes were not returned. In this respect, +// ReadAt is stricter than Read. +// +// Even if ReadAt returns n < len(p), it may use all of p as scratch +// space during the call. If some data is available but not len(p) bytes, +// ReadAt blocks until either all the data is available or an error occurs. +// In this respect ReadAt is different from Read. +// +// If the n = len(p) bytes returned by ReadAt are at the end of the +// input source, ReadAt may return either err == EOF or err == nil. +// +// If ReadAt is reading from an input source with a seek offset, +// ReadAt should not affect nor be affected by the underlying +// seek offset. +// +// Clients of ReadAt can execute parallel ReadAt calls on the +// same input source. This is however not recommended. +func (r *ReadSeeker) ReadAt(p []byte, offset int64) (int, error) { + r.readAtMu.Lock() + defer r.readAtMu.Unlock() + _, err := r.Seek(offset, io.SeekStart) + if err != nil { + return 0, err + } + n := 0 + for n < len(p) { + n2, err := r.Read(p[n:]) + if err != nil { + // This will include io.EOF + return n + n2, err + } + n += n2 } - return offset, nil + return n, nil } // ReadByte satisfies the io.ByteReader interface. diff --git a/s2/encode_test.go b/s2/encode_test.go index 29acebce70..2787563700 100644 --- a/s2/encode_test.go +++ b/s2/encode_test.go @@ -366,6 +366,37 @@ func TestIndex(t *testing.T) { } }) } + t.Run(fmt.Sprintf("ReadAt"), func(t *testing.T) { + // Read it from a seekable stream + dec = NewReader(bytes.NewReader(compressed)) + + rs, err := dec.ReadSeeker(true, nil) + fatalErr(t, err) + + // Read a little... + var tmp = make([]byte, len(input)/2) + _, err = io.ReadFull(rs, tmp[:]) + fatalErr(t, err) + wantLen := len(tmp) + if wantLen+int(wantOffset) > len(input) { + wantLen = len(input) - int(wantOffset) + } + // Read from wantOffset + n, err := rs.ReadAt(tmp, wantOffset) + if n != wantLen { + t.Errorf("got length %d, want %d", n, wantLen) + } + if err != io.EOF { + fatalErr(t, err) + } + want := want[:n] + got := tmp[:n] + + // Read the rest of the stream... + if !bytes.Equal(got, want) { + t.Error("Result mismatch", wantOffset) + } + }) }) } } diff --git a/s2/index_test.go b/s2/index_test.go index 4fc2197fe3..9d53f53ba5 100644 --- a/s2/index_test.go +++ b/s2/index_test.go @@ -234,6 +234,66 @@ func TestSeeking(t *testing.T) { } }) } + // Test seek current + t.Run(fmt.Sprintf("seekCurrent"), func(t *testing.T) { + dec := s2.NewReader(io.ReadSeeker(bytes.NewReader(compressed.Bytes()))) + + seeker, err := dec.ReadSeeker(true, index) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 25) + rng := rand.New(rand.NewSource(0)) + var currentOff int64 + for i := 0; i < nElems/10; i++ { + rec := rng.Intn(nElems) + offset := int64(rec * 25) + //t.Logf("Reading record %d", rec) + absOff, err := seeker.Seek(offset-currentOff, io.SeekCurrent) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + if absOff != offset { + t.Fatalf("Unexpected seek offset: want %v, got %v", offset, absOff) + } + _, err = io.ReadFull(dec, buf) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + expected := fmt.Sprintf("Item %019d\n", rec) + if string(buf) != expected { + t.Fatalf("Expected %q, got %q", expected, buf) + } + // Adjust offset + currentOff = offset + int64(len(buf)) + } + }) + // Test ReadAt + t.Run(fmt.Sprintf("ReadAt"), func(t *testing.T) { + dec := s2.NewReader(io.ReadSeeker(bytes.NewReader(compressed.Bytes()))) + + seeker, err := dec.ReadSeeker(true, index) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 25) + rng := rand.New(rand.NewSource(0)) + for i := 0; i < nElems/10; i++ { + rec := rng.Intn(nElems) + offset := int64(rec * 25) + n, err := seeker.ReadAt(buf, offset) + if err != nil { + t.Fatalf("Failed to seek: %v", err) + } + if n != len(buf) { + t.Fatalf("Unexpected read length: want %v, got %v", len(buf), n) + } + expected := fmt.Sprintf("Item %019d\n", rec) + if string(buf) != expected { + t.Fatalf("Expected %q, got %q", expected, buf) + } + } + }) } // ExampleIndexStream shows an example of indexing a stream