Skip to content

Commit

Permalink
fix: link header in keyset pagination (#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr authored Oct 10, 2023
1 parent 45f6c90 commit 67f2a27
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
11 changes: 6 additions & 5 deletions pagination/keysetpagination/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"strconv"
"strings"

"github.com/pkg/errors"
)
Expand Down Expand Up @@ -83,11 +84,11 @@ func header(u *url.URL, rel, token string, size int) string {
// It contains links to the first and next page, if one exists.
func Header(w http.ResponseWriter, u *url.URL, p *Paginator) {
size := p.Size()
w.Header().Set("Link", header(u, "first", p.defaultToken.Encode(), size))

if !p.IsLast() {
w.Header().Add("Link", header(u, "next", p.Token().Encode(), size))
link := []string{header(u, "first", p.defaultToken.Encode(), size)}
if !p.isLast {
link = append(link, header(u, "next", p.Token().Encode(), size))
}
w.Header().Set("Link", strings.Join(link, ","))
}

// Parse returns the pagination options from the URL query.
Expand All @@ -104,7 +105,7 @@ func Parse(q url.Values, p PageTokenConstructor) ([]Option, error) {
}
opts = append(opts, WithToken(parsed))
}
if q.Has("page_size") {
if q.Get("page_size") != "" {
size, err := strconv.Atoi(q.Get("page_size"))
if err != nil {
return nil, errors.WithStack(err)
Expand Down
26 changes: 15 additions & 11 deletions pagination/keysetpagination/header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"testing"

"github.com/peterhellberg/link"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -26,19 +27,22 @@ func TestHeader(t *testing.T) {

Header(r, u, p)

links := r.HeaderMap["Link"]
require.Len(t, links, 2)
assert.Contains(t, links[0], "page_token=default")
assert.Contains(t, links[1], "page_token=next")
assert.Len(t, r.Result().Header.Values("link"), 1, "make sure we send one header with multiple comma-separated values rather than multiple headers")

t.Run("with isLast", func(t *testing.T) {
p.isLast = true
links := link.ParseResponse(r.Result())
assert.Contains(t, links, "first")
assert.Contains(t, links["first"].URI, "page_token=default")

Header(r, u, p)
assert.Contains(t, links, "next")
assert.Contains(t, links["next"].URI, "page_token=next")

links := r.HeaderMap["Link"]
require.Len(t, links, 1)
assert.Contains(t, links[0], "page_token=default")
})
p.isLast = true
r = httptest.NewRecorder()
Header(r, u, p)
links = link.ParseResponse(r.Result())

assert.Contains(t, links, "first")
assert.Contains(t, links["first"].URI, "page_token=default")

assert.NotContains(t, links, "next")
}

0 comments on commit 67f2a27

Please sign in to comment.