From 0965420a4246492bbac6922da742b157ea968c29 Mon Sep 17 00:00:00 2001 From: Adam Date: Thu, 31 Oct 2019 14:44:54 +1100 Subject: [PATCH] Add query document caching --- graphql/cache.go | 11 ++-- graphql/handler/executor.go | 58 +++++++++++---------- graphql/handler/server.go | 6 +++ graphql/handler/server_test.go | 40 ++++++++++++-- graphql/handler/transport/http_post_test.go | 25 --------- 5 files changed, 81 insertions(+), 59 deletions(-) diff --git a/graphql/cache.go b/graphql/cache.go index bd55f0b8e14..24294f72d7f 100644 --- a/graphql/cache.go +++ b/graphql/cache.go @@ -6,7 +6,7 @@ type Cache interface { Get(key string) (value interface{}, ok bool) // Add adds a value to the cache. - Add(key, value string) + Add(key string, value interface{}) } // MapCache is the simplest implementation of a cache, because it can not evict it should only be used in tests @@ -19,6 +19,9 @@ func (m MapCache) Get(key string) (value interface{}, ok bool) { } // Add adds a value to the cache. -func (m MapCache) Add(key, value string) { - m[key] = value -} +func (m MapCache) Add(key string, value interface{}) { m[key] = value } + +type NoCache struct{} + +func (n NoCache) Get(key string) (value interface{}, ok bool) { return nil, false } +func (n NoCache) Add(key string, value interface{}) {} diff --git a/graphql/handler/executor.go b/graphql/handler/executor.go index eba9714da2a..6e7fd18c116 100644 --- a/graphql/handler/executor.go +++ b/graphql/handler/executor.go @@ -87,8 +87,6 @@ func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawP } } - var gerr *gqlerror.Error - rc := &graphql.RequestContext{ DisableIntrospection: true, Recover: graphql.DefaultRecover, @@ -100,21 +98,22 @@ func (e executor) CreateRequestContext(ctx context.Context, params *graphql.RawP } rc.Stats.OperationStart = graphql.GetStartTime(ctx) - rc.Doc, gerr = e.parseOperation(ctx, rc) - if gerr != nil { - return nil, []*gqlerror.Error{gerr} - } - - ctx, op, listErr := e.validateOperation(ctx, rc) + var listErr gqlerror.List + rc.Doc, listErr = e.parseQuery(ctx, rc) if len(listErr) != 0 { return nil, listErr } + op := rc.Doc.Operations.ForName(rc.OperationName) + if op == nil { + return nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)} + } + vars, err := validator.VariableValues(e.server.es.Schema(), op, rc.Variables) if err != nil { return nil, gqlerror.List{err} } - + rc.Stats.Validation.End = graphql.Now() rc.Variables = vars for _, p := range e.requestContextMutators { @@ -180,29 +179,34 @@ func (e *executor) executableSchemaHandler(ctx context.Context, write graphql.Wr } } -func (e executor) parseOperation(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, *gqlerror.Error) { +// parseQuery decodes the incoming query and validates it, pulling from cache if present. +// +// NOTE: This should NOT look at variables, they will change per request. It should only parse and validate +// the raw query string. +func (e executor) parseQuery(ctx context.Context, rc *graphql.RequestContext) (*ast.QueryDocument, gqlerror.List) { rc.Stats.Parsing.Start = graphql.Now() - defer func() { - rc.Stats.Parsing.End = graphql.Now() - }() - return parser.ParseQuery(&ast.Source{Input: rc.RawQuery}) -} -func (e executor) validateOperation(ctx context.Context, rc *graphql.RequestContext) (context.Context, *ast.OperationDefinition, gqlerror.List) { - rc.Stats.Validation.Start = graphql.Now() - defer func() { - rc.Stats.Validation.End = graphql.Now() - }() + if doc, ok := e.server.queryCache.Get(rc.RawQuery); ok { + now := graphql.Now() - listErr := validator.Validate(e.server.es.Schema(), rc.Doc) - if len(listErr) != 0 { - return ctx, nil, listErr + rc.Stats.Parsing.End = now + rc.Stats.Validation.Start = now + return doc.(*ast.QueryDocument), nil } - op := rc.Doc.Operations.ForName(rc.OperationName) - if op == nil { - return ctx, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", rc.OperationName)} + doc, err := parser.ParseQuery(&ast.Source{Input: rc.RawQuery}) + if err != nil { + return nil, gqlerror.List{err} + } + rc.Stats.Parsing.End = graphql.Now() + + rc.Stats.Validation.Start = graphql.Now() + listErr := validator.Validate(e.server.es.Schema(), doc) + if len(listErr) != 0 { + return nil, listErr } - return ctx, op, nil + e.server.queryCache.Add(rc.RawQuery, doc) + + return doc, nil } diff --git a/graphql/handler/server.go b/graphql/handler/server.go index 03ccae9dfbd..50b9379033d 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -18,6 +18,7 @@ type ( errorPresenter graphql.ErrorPresenterFunc recoverFunc graphql.RecoverFunc + queryCache graphql.Cache } ) @@ -26,6 +27,7 @@ func New(es graphql.ExecutableSchema) *Server { es: es, errorPresenter: graphql.DefaultErrorPresenter, recoverFunc: graphql.DefaultRecover, + queryCache: graphql.NoCache{}, } s.exec = newExecutor(s) return s @@ -43,6 +45,10 @@ func (s *Server) SetRecoverFunc(f graphql.RecoverFunc) { s.recoverFunc = f } +func (s *Server) SetQueryCache(cache graphql.Cache) { + s.queryCache = cache +} + func (s *Server) Use(plugin graphql.HandlerPlugin) { switch plugin.(type) { case graphql.RequestParameterMutator, diff --git a/graphql/handler/server_test.go b/graphql/handler/server_test.go index fe8ab220871..a966cd8eb20 100644 --- a/graphql/handler/server_test.go +++ b/graphql/handler/server_test.go @@ -2,11 +2,16 @@ package handler_test import ( "context" - "fmt" "net/http" "net/http/httptest" + "net/url" "testing" + "github.com/vektah/gqlparser/parser" + + "github.com/stretchr/testify/require" + "github.com/vektah/gqlparser/ast" + "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler/testserver" "github.com/99designs/gqlgen/graphql/handler/transport" @@ -60,12 +65,10 @@ func TestServer(t *testing.T) { t.Run("invokes field middleware in order", func(t *testing.T) { var calls []string srv.Use(fieldFunc(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - fmt.Println("first") calls = append(calls, "first") return next(ctx) })) srv.Use(fieldFunc(func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { - fmt.Println("second") calls = append(calls, "second") return next(ctx) })) @@ -74,6 +77,37 @@ func TestServer(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String()) assert.Equal(t, []string{"first", "second"}, calls) }) + + t.Run("query caching", func(t *testing.T) { + cache := &graphql.MapCache{} + srv.SetQueryCache(cache) + qry := `query Foo {name}` + + t.Run("cache miss populates cache", func(t *testing.T) { + resp := get(srv, "/foo?query="+url.QueryEscape(qry)) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + + cacheDoc, ok := cache.Get(qry) + require.True(t, ok) + require.Equal(t, "Foo", cacheDoc.(*ast.QueryDocument).Operations[0].Name) + }) + + t.Run("cache hits use document from cache", func(t *testing.T) { + doc, err := parser.ParseQuery(&ast.Source{Input: `query Bar {name}`}) + require.Nil(t, err) + cache.Add(qry, doc) + + resp := get(srv, "/foo?query="+url.QueryEscape(qry)) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + + cacheDoc, ok := cache.Get(qry) + require.True(t, ok) + require.Equal(t, "Bar", cacheDoc.(*ast.QueryDocument).Operations[0].Name) + }) + }) + } type opFunc func(ctx context.Context, next graphql.OperationHandler, writer graphql.Writer) diff --git a/graphql/handler/transport/http_post_test.go b/graphql/handler/transport/http_post_test.go index 86411767ca7..8e389b023be 100644 --- a/graphql/handler/transport/http_post_test.go +++ b/graphql/handler/transport/http_post_test.go @@ -22,31 +22,6 @@ func TestPOST(t *testing.T) { assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) }) - // Todo: Extract - //t.Run("query caching", func(t *testing.T) { - // // Run enough unique queries to evict a bunch of them - // for i := 0; i < 2000; i++ { - // query := `{"query":"` + strings.Repeat(" ", i) + "{ me { name } }" + `"}` - // resp := doRequest(h, "POST", "/graphql", query) - // assert.Equal(t, http.StatusOK, resp.Code) - // assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - // } - // - // t.Run("evicted queries run", func(t *testing.T) { - // query := `{"query":"` + strings.Repeat(" ", 0) + "{ me { name } }" + `"}` - // resp := doRequest(h, "POST", "/graphql", query) - // assert.Equal(t, http.StatusOK, resp.Code) - // assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - // }) - // - // t.Run("non-evicted queries run", func(t *testing.T) { - // query := `{"query":"` + strings.Repeat(" ", 1999) + "{ me { name } }" + `"}` - // resp := doRequest(h, "POST", "/graphql", query) - // assert.Equal(t, http.StatusOK, resp.Code) - // assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) - // }) - //}) - t.Run("decode failure", func(t *testing.T) { resp := doRequest(h, "POST", "/graphql", "notjson") assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String())