Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query complexity calculation and limits #315

Merged
merged 11 commits into from
Aug 31, 2018
5 changes: 3 additions & 2 deletions codegen/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ type TypeMapEntry struct {
}

type TypeMapField struct {
Resolver bool `yaml:"resolver"`
FieldName string `yaml:"fieldName"`
Resolver bool `yaml:"resolver"`
Complexity bool `yaml:"complexity"`
FieldName string `yaml:"fieldName"`
}

func (c *PackageConfig) normalize() error {
Expand Down
48 changes: 38 additions & 10 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ type Object struct {
type Field struct {
*Type

GQLName string // The name of the field in graphql
GoFieldType GoFieldType // The field type in go, if any
GoReceiverName string // The name of method & var receiver in go, if any
GoFieldName string // The name of the method or var in go, if any
Args []FieldArgument // A list of arguments to be passed to this field
ForceResolver bool // Should be emit Resolver method
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
Object *Object // A link back to the parent object
Default interface{} // The default value
GQLName string // The name of the field in graphql
GoFieldType GoFieldType // The field type in go, if any
GoReceiverName string // The name of method & var receiver in go, if any
GoFieldName string // The name of the method or var in go, if any
Args []FieldArgument // A list of arguments to be passed to this field
ForceResolver bool // Should be emit Resolver method
CustomComplexity bool // Uses a custom complexity calculation
NoErr bool // If this is bound to a go method, does that method have an error as the second argument
Object *Object // A link back to the parent object
Default interface{} // The default value
}

type FieldArgument struct {
Expand Down Expand Up @@ -81,6 +82,15 @@ func (o *Object) IsConcurrent() bool {
return false
}

func (o *Object) HasComplexity() bool {
for _, f := range o.Fields {
if f.CustomComplexity {
return true
}
}
return false
}

func (f *Field) IsResolver() bool {
return f.GoFieldName == ""
}
Expand Down Expand Up @@ -165,6 +175,24 @@ func (f *Field) ResolverDeclaration() string {
return res
}

func (f *Field) ComplexitySignature() string {
res := fmt.Sprintf("func(childComplexity int")
for _, arg := range f.Args {
res += fmt.Sprintf(", %s %s", arg.GoVarName, arg.Signature())
}
res += ") int"
return res
}

func (f *Field) ComplexityArgs() string {
var args []string
for _, arg := range f.Args {
args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
}

return strings.Join(args, ", ")
}

func (f *Field) CallArgs() string {
var args []string

Expand Down Expand Up @@ -227,7 +255,7 @@ func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Typ
ctx := graphql.WithResolverContext(ctx, rctx)
{{- end}}
{{.arr}} = append({{.arr}}, func() graphql.Marshaler {
{{ .next }}
{{ .next }}
}())
}
return {{.arr}}`, map[string]interface{}{
Expand Down
15 changes: 9 additions & 6 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I
}

var forceResolver bool
var customComplexity bool
var goName string
if entryExists {
if typeField, ok := typeEntry.Fields[field.Name]; ok {
goName = typeField.FieldName
forceResolver = typeField.Resolver
customComplexity = typeField.Complexity
}
}

Expand Down Expand Up @@ -168,12 +170,13 @@ func (cfg *Config) buildObject(types NamedTypes, typ *ast.Definition, imports *I
}

obj.Fields = append(obj.Fields, Field{
GQLName: field.Name,
Type: types.getType(field.Type),
Args: args,
Object: obj,
GoFieldName: goName,
ForceResolver: forceResolver,
GQLName: field.Name,
Type: types.getType(field.Type),
Args: args,
Object: obj,
GoFieldName: goName,
ForceResolver: forceResolver,
CustomComplexity: customComplexity,
})
}

Expand Down
2 changes: 1 addition & 1 deletion codegen/templates/data.go

Large diffs are not rendered by default.

48 changes: 47 additions & 1 deletion codegen/templates/generated.gotpl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ func NewExecutableSchema(cfg Config) graphql.ExecutableSchema {
return &executableSchema{
resolvers: cfg.Resolvers,
directives: cfg.Directives,
complexity: cfg.Complexity,
}
}

type Config struct {
Resolvers ResolverRoot
Directives DirectiveRoot
Complexity ComplexityRoot
}

type ResolverRoot interface {
Expand All @@ -35,7 +37,21 @@ type DirectiveRoot struct {
{{ end }}
}

{{- range $object := .Objects -}}
type ComplexityRoot struct {
{{ range $object := .Objects }}
{{ if $object.HasComplexity }}
{{ $object.GoType }} struct {
{{ range $field := $object.Fields }}
{{ if $field.CustomComplexity }}
{{ $field.GoFieldName }} {{ $field.ComplexitySignature }}
{{ end }}
{{ end }}
}
{{ end }}
{{ end }}
}

{{ range $object := .Objects -}}
{{ if $object.HasResolvers }}
type {{$object.GQLType}}Resolver interface {
{{ range $field := $object.Fields -}}
Expand All @@ -48,12 +64,42 @@ type DirectiveRoot struct {
type executableSchema struct {
resolvers ResolverRoot
directives DirectiveRoot
complexity ComplexityRoot
}

func (e *executableSchema) Schema() *ast.Schema {
return parsedSchema
}

func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ range $field := $object.Fields }}
{{ if $field.CustomComplexity }}
case "{{$object.GQLType}}.{{$field.GQLName}}":
if e.complexity.{{$object.GoType}}.{{$field.GoFieldName}} == nil {
break
}
{{ if . }}args := map[string]interface{}{} {{end}}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is basically copied from args.gotpl. There's probably a better way to do this with less duplication.

In general, it seems like it might be too expensive to be doing all the argument conversion and unmarshaling twice, as well. (Now it will happen once for the complexity calculation, and then again for the actual query execution.)

Copy link
Collaborator

@vektah vektah Aug 24, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be pretty cheap as long as it only happens if there is a complexity function defined? Might be a bunch of extra generated code, but that's OK.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What prevents you from calling args.gotpl directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wants to return graphql.Null in error cases, which is the wrong type for this function.

{{ range $i, $arg := $field.Args }}
var arg{{$i}} {{$arg.Signature }}
if tmp, ok := rawArgs[{{$arg.GQLName|quote}}]; ok {
var err error
{{$arg.Unmarshal (print "arg" $i) "tmp" }}
if err != nil {
return 0, false
}
}
args[{{$arg.GQLName|quote}}] = arg{{$i}}
{{ end }}
return e.complexity.{{$object.GoType}}.{{$field.GoFieldName}}(childComplexity{{if $field.Args}}, {{$field.ComplexityArgs}} {{end}}), true
{{ end }}
{{ end }}
{{ end }}
}
return 0, false
}

func (e *executableSchema) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
{{- if .QueryRoot }}
ec := executionContext{graphql.GetRequestContext(ctx), e}
Expand Down
85 changes: 85 additions & 0 deletions complexity/complexity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package complexity

import (
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/ast"
)

func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
walker := complexityWalker{
es: es,
vars: vars,
}
typeName := ""
switch op.Operation {
case ast.Query:
typeName = es.Schema().Query.Name
case ast.Mutation:
typeName = es.Schema().Mutation.Name
case ast.Subscription:
typeName = es.Schema().Subscription.Name
}
return walker.selectionSetComplexity(typeName, op.SelectionSet)
}

type complexityWalker struct {
es graphql.ExecutableSchema
vars map[string]interface{}
}

func (cw complexityWalker) selectionSetComplexity(typeName string, selectionSet ast.SelectionSet) int {
var complexity int
for _, selection := range selectionSet {
switch s := selection.(type) {
case *ast.Field:
var childComplexity int
switch s.ObjectDefinition.Kind {
case ast.Object, ast.Interface, ast.Union:
childComplexity = cw.selectionSetComplexity(s.ObjectDefinition.Name, s.SelectionSet)
}

args := s.ArgumentMap(cw.vars)
if customComplexity, ok := cw.es.Complexity(typeName, s.Name, childComplexity, args); ok {
complexity = safeAdd(complexity, customComplexity)
} else {
// default complexity calculation
complexity = safeAdd(complexity, safeAdd(1, childComplexity))
}

case *ast.FragmentSpread:
complexity = safeAdd(complexity, cw.selectionSetComplexity(typeName, s.Definition.SelectionSet))

case *ast.InlineFragment:
complexity = safeAdd(complexity, cw.selectionSetComplexity(typeName, s.SelectionSet))
}
}
return complexity
}

// safeAdd is a saturating add of a and b that ignores negative operands.
// If a + b would overflow through normal Go addition,
// it returns the maximum integer value instead.
//
// Adding complexities with this function prevents attackers from intentionally
// overflowing the complexity calculation to allow overly-complex queries.
//
// It also helps mitigate the impact of custom complexities that accidentally
// return negative values.
func safeAdd(a, b int) int {
// Ignore negative operands.
if a <= 0 {
if b < 0 {
return 0
}
return b
} else if b <= 0 {
return a
}

c := a + b
if c < a {
// Set c to maximum integer instead of overflowing.
c = int(^uint(0) >> 1)
}
return c
}
1 change: 1 addition & 0 deletions graphql/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
type ExecutableSchema interface {
Schema() *ast.Schema

Complexity(typeName, fieldName string, childComplexity int, args map[string]interface{}) (int, bool)
Query(ctx context.Context, op *ast.OperationDefinition) *Response
Mutation(ctx context.Context, op *ast.OperationDefinition) *Response
Subscription(ctx context.Context, op *ast.OperationDefinition) func() *Response
Expand Down
28 changes: 23 additions & 5 deletions handler/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"strings"

"github.com/99designs/gqlgen/complexity"
"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
"github.com/vektah/gqlparser"
Expand All @@ -23,11 +24,12 @@ type params struct {
}

type Config struct {
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
upgrader websocket.Upgrader
recover graphql.RecoverFunc
errorPresenter graphql.ErrorPresenterFunc
resolverHook graphql.FieldMiddleware
requestHook graphql.RequestMiddleware
complexityLimit int
}

func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext {
Expand Down Expand Up @@ -74,6 +76,14 @@ func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
}
}

// ComplexityLimit sets a maximum query complexity that is allowed to be executed.
// If a query is submitted that exceeds the limit, a 422 status code will be returned.
func ComplexityLimit(limit int) Option {
return func(cfg *Config) {
cfg.complexityLimit = limit
}
}

// ResolverMiddleware allows you to define a function that will be called around every resolver,
// useful for tracing and logging.
func ResolverMiddleware(middleware graphql.FieldMiddleware) Option {
Expand Down Expand Up @@ -184,6 +194,14 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc
}
}()

if cfg.complexityLimit > 0 {
queryComplexity := complexity.Calculate(exec, op, vars)
if queryComplexity > cfg.complexityLimit {
sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit)
return
}
}

switch op.Operation {
case ast.Query:
b, err := json.Marshal(exec.Query(ctx, op))
Expand Down
4 changes: 4 additions & 0 deletions handler/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func (e *executableSchemaStub) Schema() *ast.Schema {
`})
}

func (e *executableSchemaStub) Complexity(typeName, field string, childComplexity int, args map[string]interface{}) (int, bool) {
return 0, false
}

func (e *executableSchemaStub) Query(ctx context.Context, op *ast.OperationDefinition) *graphql.Response {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}
}
Expand Down