Skip to content

Commit

Permalink
Add query document caching
Browse files Browse the repository at this point in the history
  • Loading branch information
vektah committed Oct 31, 2019
1 parent aede7d1 commit 0965420
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 59 deletions.
11 changes: 7 additions & 4 deletions graphql/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ type Cache interface {
Get(key string) (value interface{}, ok bool)

// Add adds a value to the cache.
Add(key, value string)
Add(key string, value interface{})
}

// MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests
Expand All @@ -19,6 +19,9 @@ func (m MapCache) Get(key string) (value interface{}, ok bool) {
}

// Add adds a value to the cache.
func (m MapCache) Add(key, value string) {
m[key] = value
}
func (m MapCache) Add(key string, value interface{}) { m[key] = value }

type NoCache struct{}

func (n NoCache) Get(key string) (value interface{}, ok bool) { return nil, false }
func (n NoCache) Add(key string, value interface{}) {}
58 changes: 31 additions & 27 deletions graphql/handler/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawP
}
}

var gerr *gqlerror.Error

rc := &graphql.RequestContext{
DisableIntrospection: true,
Recover: graphql.DefaultRecover,
Expand All @@ -100,21 +98,22 @@ func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawP
}
rc.Stats.OperationStart = graphql.GetStartTime(ctx)

rc.Doc, gerr = e.parseOperation(ctx, rc)
if gerr != nil {
return nil, []*gqlerror.Error{gerr}
}

ctx, op, listErr := e.validateOperation(ctx, rc)
var listErr gqlerror.List
rc.Doc, listErr = e.parseQuery(ctx, rc)
if len(listErr) != 0 {
return nil, listErr
}

op := rc.Doc.Operations.ForName(rc.OperationName)
if op == nil {
return nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)}
}

vars, err := validator.VariableValues(e.server.es.Schema(), op, rc.Variables)
if err != nil {
return nil, gqlerror.List{err}
}

rc.Stats.Validation.End = graphql.Now()
rc.Variables = vars

for _, p := range e.requestContextMutators {
Expand Down Expand Up @@ -180,29 +179,34 @@ func (e *executor) executableSchemaHandler(ctx context.Context, write graphql.Wr
}
}

func (e executor) parseOperation(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, *gqlerror.Error) {
// parseQuery decodes the incoming query and validates it, pulling from cache if present.
//
// NOTE: This should NOT look at variables, they will change per request. It should only parse and validate
// the raw query string.
func (e executor) parseQuery(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, gqlerror.List) {
rc.Stats.Parsing.Start = graphql.Now()
defer func() {
rc.Stats.Parsing.End = graphql.Now()
}()
return parser.ParseQuery(&ast.Source{Input: rc.RawQuery})
}

func (e executor) validateOperation(ctx context.Context, rc *graphql.RequestContext) (context.Context, *ast.OperationDefinition, gqlerror.List) {
rc.Stats.Validation.Start = graphql.Now()
defer func() {
rc.Stats.Validation.End = graphql.Now()
}()
if doc, ok := e.server.queryCache.Get(rc.RawQuery); ok {
now := graphql.Now()

listErr := validator.Validate(e.server.es.Schema(), rc.Doc)
if len(listErr) != 0 {
return ctx, nil, listErr
rc.Stats.Parsing.End = now
rc.Stats.Validation.Start = now
return doc.(*ast.QueryDocument), nil
}

op := rc.Doc.Operations.ForName(rc.OperationName)
if op == nil {
return ctx, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)}
doc, err := parser.ParseQuery(&ast.Source{Input: rc.RawQuery})
if err != nil {
return nil, gqlerror.List{err}
}
rc.Stats.Parsing.End = graphql.Now()

rc.Stats.Validation.Start = graphql.Now()
listErr := validator.Validate(e.server.es.Schema(), doc)
if len(listErr) != 0 {
return nil, listErr
}

return ctx, op, nil
e.server.queryCache.Add(rc.RawQuery, doc)

return doc, nil
}
6 changes: 6 additions & 0 deletions graphql/handler/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type (

errorPresenter graphql.ErrorPresenterFunc
recoverFunc graphql.RecoverFunc
queryCache graphql.Cache
}
)

Expand All @@ -26,6 +27,7 @@ func New(es graphql.ExecutableSchema) *Server {
es: es,
errorPresenter: graphql.DefaultErrorPresenter,
recoverFunc: graphql.DefaultRecover,
queryCache: graphql.NoCache{},
}
s.exec = newExecutor(s)
return s
Expand All @@ -43,6 +45,10 @@ func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) {
s.recoverFunc = f
}

func (s *Server) SetQueryCache(cache graphql.Cache) {
s.queryCache = cache
}

func (s *Server) Use(plugin graphql.HandlerPlugin) {
switch plugin.(type) {
case graphql.RequestParameterMutator,
Expand Down
40 changes: 37 additions & 3 deletions graphql/handler/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ package handler_test

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/vektah/gqlparser/parser"

"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/ast"

"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/testserver"
"github.com/99designs/gqlgen/graphql/handler/transport"
Expand Down Expand Up @@ -60,12 +65,10 @@ func TestServer(t *testing.T) {
t.Run("invokes field middleware in order", func(t *testing.T) {
var calls []string
srv.Use(fieldFunc(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
fmt.Println("first")
calls = append(calls, "first")
return next(ctx)
}))
srv.Use(fieldFunc(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) {
fmt.Println("second")
calls = append(calls, "second")
return next(ctx)
}))
Expand All @@ -74,6 +77,37 @@ func TestServer(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
assert.Equal(t, []string{"first", "second"}, calls)
})

t.Run("query caching", func(t *testing.T) {
cache := &graphql.MapCache{}
srv.SetQueryCache(cache)
qry := `query Foo {name}`

t.Run("cache miss populates cache", func(t *testing.T) {
resp := get(srv, "/foo?query="+url.QueryEscape(qry))
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())

cacheDoc, ok := cache.Get(qry)
require.True(t, ok)
require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
})

t.Run("cache hits use document from cache", func(t *testing.T) {
doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`})
require.Nil(t, err)
cache.Add(qry, doc)

resp := get(srv, "/foo?query="+url.QueryEscape(qry))
assert.Equal(t, http.StatusOK, resp.Code)
assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())

cacheDoc, ok := cache.Get(qry)
require.True(t, ok)
require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name)
})
})

}

type opFunc func(ctx context.Context, next graphql.OperationHandler, writer graphql.Writer)
Expand Down
25 changes: 0 additions & 25 deletions graphql/handler/transport/http_post_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,6 @@ func TestPOST(t *testing.T) {
assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
})

// Todo: Extract
//t.Run("query caching", func(t *testing.T) {
// // Run enough unique queries to evict a bunch of them
// for i := 0; i < 2000; i++ {
// query := `{"query":"` + strings.Repeat(" ", i) + "{ me { name } }" + `"}`
// resp := doRequest(h, "POST", "/graphql", query)
// assert.Equal(t, http.StatusOK, resp.Code)
// assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
// }
//
// t.Run("evicted queries run", func(t *testing.T) {
// query := `{"query":"` + strings.Repeat(" ", 0) + "{ me { name } }" + `"}`
// resp := doRequest(h, "POST", "/graphql", query)
// assert.Equal(t, http.StatusOK, resp.Code)
// assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
// })
//
// t.Run("non-evicted queries run", func(t *testing.T) {
// query := `{"query":"` + strings.Repeat(" ", 1999) + "{ me { name } }" + `"}`
// resp := doRequest(h, "POST", "/graphql", query)
// assert.Equal(t, http.StatusOK, resp.Code)
// assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String())
// })
//})

t.Run("decode failure", func(t *testing.T) {
resp := doRequest(h, "POST", "/graphql", "notjson")
assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String())
Expand Down

0 comments on commit 0965420

Please sign in to comment.