From a027ac21c773ed1bf71ec6017e5cafbd140305a2 Mon Sep 17 00:00:00 2001 From: vvakame Date: Mon, 29 Oct 2018 18:54:22 +0900 Subject: [PATCH 1/2] copy complexity to RequestContext --- graphql/context.go | 4 ++++ handler/graphql.go | 19 +++++++++++-------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/graphql/context.go b/graphql/context.go index b6586e96046..a46844cb40a 100644 --- a/graphql/context.go +++ b/graphql/context.go @@ -17,6 +17,10 @@ type RequestContext struct { RawQuery string Variables map[string]interface{} Doc *ast.QueryDocument + + ComplexityLimit int + OperationComplexity int + // ErrorPresenter will be used to generate the error // message from errors given to Error(). ErrorPresenter ErrorPresenterFunc diff --git a/handler/graphql.go b/handler/graphql.go index 13d013613a3..b09c344f540 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -35,7 +35,7 @@ type Config struct { complexityLimit int } -func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext { +func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { reqCtx := graphql.NewRequestContext(doc, query, variables) if hook := c.recover; hook != nil { reqCtx.Recover = hook @@ -59,6 +59,12 @@ func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variabl reqCtx.Tracer = &graphql.NopTracer{} } + if c.complexityLimit > 0 { + reqCtx.ComplexityLimit = c.complexityLimit + operationComplexity := complexity.Calculate(es, op, variables) + reqCtx.OperationComplexity = operationComplexity + } + return reqCtx } @@ -298,7 +304,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc sendError(w, http.StatusUnprocessableEntity, err) return } - reqCtx := cfg.newRequestContext(doc, reqParams.Query, vars) + reqCtx := cfg.newRequestContext(exec, doc, op, reqParams.Query, vars) ctx := graphql.WithRequestContext(r.Context(), reqCtx) defer func() { @@ -308,12 +314,9 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } }() - if cfg.complexityLimit > 0 { - queryComplexity := complexity.Calculate(exec, op, vars) - if queryComplexity > cfg.complexityLimit { - sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit) - return - } + if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > cfg.complexityLimit { + sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", operationComplexity, cfg.complexityLimit) + return } switch op.Operation { From 784dc01fdb4f0759e59e32bb48814b94760ca00b Mon Sep 17 00:00:00 2001 From: vvakame Date: Mon, 29 Oct 2018 19:17:04 +0900 Subject: [PATCH 2/2] oops... --- handler/graphql.go | 2 +- handler/websocket.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/handler/graphql.go b/handler/graphql.go index b09c344f540..1e085ecd169 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -315,7 +315,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc }() if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > cfg.complexityLimit { - sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", operationComplexity, cfg.complexityLimit) + sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit) return } diff --git a/handler/websocket.go b/handler/websocket.go index af7d7a5a4e4..dae262bdf3a 100644 --- a/handler/websocket.go +++ b/handler/websocket.go @@ -165,7 +165,7 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { c.sendError(message.ID, err) return true } - reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars) + reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars) ctx := graphql.WithRequestContext(c.ctx, reqCtx) if c.initPayload != nil {