diff --git a/README.md b/README.md index 8d68d2c7..01038bfa 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,17 @@ $ curl -XPOST -d '{"query": "{ hello }"}' localhost:8080/query ### Resolvers -A resolver must have one method for each field of the GraphQL type it resolves. The method name has to be [exported](https://golang.org/ref/spec#Exported_identifiers) and match the field's name in a non-case-sensitive way. +A resolver must have one method or field for each field of the GraphQL type it resolves. The method or field name has to be [exported](https://golang.org/ref/spec#Exported_identifiers) and match the schema's field's name in a non-case-sensitive way. +You can use struct fields as resolvers by using `SchemaOpt: UseFieldResolvers()`. For example, +``` +opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()} +schema := graphql.MustParseSchema(s, &query{}, opts...) +``` + +When using `UseFieldResolvers` schema option, a struct field will be used *only* when: +- there is no method for a struct field +- a struct field does not implement an interface method +- a struct field does not have arguments The method has up to two arguments: diff --git a/example/social/README.md b/example/social/README.md new file mode 100644 index 00000000..5ab316fd --- /dev/null +++ b/example/social/README.md @@ -0,0 +1,9 @@ +### Social App + +A simple example of how to use struct fields as resolvers instead of methods. + +To run this server + +`go run ./example/field-resolvers/server/server.go` + +and go to localhost:9011 to interact \ No newline at end of file diff --git a/example/social/server/server.go b/example/social/server/server.go new file mode 100644 index 00000000..6bfde72b --- /dev/null +++ b/example/social/server/server.go @@ -0,0 +1,62 @@ +package main + +import ( + "log" + "net/http" + + "github.com/graph-gophers/graphql-go" + "github.com/graph-gophers/graphql-go/example/social" + "github.com/graph-gophers/graphql-go/relay" +) + +func main() { + opts := []graphql.SchemaOpt{graphql.UseFieldResolvers(), graphql.MaxParallelism(20)} + schema := graphql.MustParseSchema(social.Schema, &social.Resolver{}, opts...) + + http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(page) + })) + + http.Handle("/query", &relay.Handler{Schema: schema}) + + log.Fatal(http.ListenAndServe(":9011", nil)) +} + +var page = []byte(` + + + + + + + + + + + +
Loading...
+ + + +`) diff --git a/example/social/social.go b/example/social/social.go new file mode 100644 index 00000000..67774207 --- /dev/null +++ b/example/social/social.go @@ -0,0 +1,206 @@ +package social + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/graph-gophers/graphql-go" +) + +const Schema = ` + schema { + query: Query + } + + type Query { + admin(id: ID!, role: Role = ADMIN): Admin! + user(id: ID!): User! + search(text: String!): [SearchResult]! + } + + interface Admin { + id: ID! + name: String! + role: Role! + } + + scalar Time + + type User implements Admin { + id: ID! + name: String! + email: String! + role: Role! + phone: String! + address: [String!] + friends(page: Pagination): [User] + createdAt: Time! + } + + input Pagination { + first: Int + last: Int + } + + enum Role { + ADMIN + USER + } + + union SearchResult = User +` + +type page struct { + First *float64 + Last *float64 +} + +type admin interface { + ID() graphql.ID + Name() string + Role() string +} + +type searchResult struct { + result interface{} +} + +func (r *searchResult) ToUser() (*user, bool) { + res, ok := r.result.(*user) + return res, ok +} + +type user struct { + IDField string + NameField string + RoleField string + Email string + Phone string + Address *[]string + Friends *[]*user + CreatedAt graphql.Time +} + +func (u user) ID() graphql.ID { + return graphql.ID(u.IDField) +} + +func (u user) Name() string { + return u.NameField +} + +func (u user) Role() string { + return u.RoleField +} + +func (u user) FriendsResolver(args struct{ Page *page }) (*[]*user, error) { + var from int + numFriends := len(*u.Friends) + to := numFriends + + if args.Page != nil { + if args.Page.First != nil { + from = int(*args.Page.First) + if from > numFriends { + return nil, errors.New("not enough users") + } + } + if args.Page.Last != nil { + to = int(*args.Page.Last) + if to == 0 || to > numFriends { + to = numFriends + } + } + } + + friends := (*u.Friends)[from:to] + + return &friends, nil +} + +var users = []*user{ + { + IDField: "0x01", + NameField: "Albus Dumbledore", + RoleField: "ADMIN", + Email: "Albus@hogwarts.com", + Phone: "000-000-0000", + Address: &[]string{"Office @ Hogwarts", "where Horcruxes are"}, + CreatedAt: graphql.Time{Time: time.Now()}, + }, + { + IDField: "0x02", + NameField: "Harry Potter", + RoleField: "USER", + Email: "harry@hogwarts.com", + Phone: "000-000-0001", + Address: &[]string{"123 dorm room @ Hogwarts", "456 random place"}, + CreatedAt: graphql.Time{Time: time.Now()}, + }, + { + IDField: "0x03", + NameField: "Hermione Granger", + RoleField: "USER", + Email: "hermione@hogwarts.com", + Phone: "000-000-0011", + Address: &[]string{"233 dorm room @ Hogwarts", "786 @ random place"}, + CreatedAt: graphql.Time{Time: time.Now()}, + }, + { + IDField: "0x04", + NameField: "Ronald Weasley", + RoleField: "USER", + Email: "ronald@hogwarts.com", + Phone: "000-000-0111", + Address: &[]string{"411 dorm room @ Hogwarts", "981 @ random place"}, + CreatedAt: graphql.Time{Time: time.Now()}, + }, +} + +var usersMap = make(map[string]*user) + +func init() { + users[0].Friends = &[]*user{users[1]} + users[1].Friends = &[]*user{users[0], users[2], users[3]} + users[2].Friends = &[]*user{users[1], users[3]} + users[3].Friends = &[]*user{users[1], users[2]} + for _, usr := range users { + usersMap[usr.IDField] = usr + } +} + +type Resolver struct{} + +func (r *Resolver) Admin(ctx context.Context, args struct { + ID string + Role string +}) (admin, error) { + if usr, ok := usersMap[args.ID]; ok { + if usr.RoleField == args.Role { + return *usr, nil + } + } + err := fmt.Errorf("user with id=%s and role=%s does not exist", args.ID, args.Role) + return user{}, err +} + +func (r *Resolver) User(ctx context.Context, args struct{ Id string }) (user, error) { + if usr, ok := usersMap[args.Id]; ok { + return *usr, nil + } + err := fmt.Errorf("user with id=%s does not exist", args.Id) + return user{}, err +} + +func (r *Resolver) Search(ctx context.Context, args struct{ Text string }) ([]*searchResult, error) { + var result []*searchResult + for _, usr := range users { + if strings.Contains(usr.NameField, args.Text) { + result = append(result, &searchResult{usr}) + } + } + return result, nil +} diff --git a/graphql.go b/graphql.go index 35768a47..f3fe32eb 100644 --- a/graphql.go +++ b/graphql.go @@ -2,9 +2,8 @@ package graphql import ( "context" - "fmt" - "encoding/json" + "fmt" "github.com/graph-gophers/graphql-go/errors" "github.com/graph-gophers/graphql-go/internal/common" @@ -84,6 +83,13 @@ func UseStringDescriptions() SchemaOpt { } } +// UseFieldResolvers specifies whether to use struct field resolvers +func UseFieldResolvers() SchemaOpt { + return func(s *Schema) { + s.schema.UseFieldResolvers = true + } +} + // MaxDepth specifies the maximum field nesting depth in a query. The default is 0 which disables max depth checking. func MaxDepth(n int) SchemaOpt { return func(s *Schema) { diff --git a/internal/exec/exec.go b/internal/exec/exec.go index c326fc95..e878888f 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -178,24 +178,33 @@ func execFieldSelection(ctx context.Context, r *Request, f *fieldToExec, path *p return errors.Errorf("%s", err) // don't execute any more resolvers if context got cancelled } - var in []reflect.Value - if f.field.HasContext { - in = append(in, reflect.ValueOf(traceCtx)) - } - if f.field.ArgsPacker != nil { - in = append(in, f.field.PackedArgs) - } - callOut := f.resolver.Method(f.field.MethodIndex).Call(in) - result = callOut[0] - if f.field.HasError && !callOut[1].IsNil() { - resolverErr := callOut[1].Interface().(error) - err := errors.Errorf("%s", resolverErr) - err.Path = path.toSlice() - err.ResolverError = resolverErr - if ex, ok := callOut[1].Interface().(extensionser); ok { - err.Extensions = ex.Extensions() + res := f.resolver + if f.field.UseMethodResolver() { + var in []reflect.Value + if f.field.HasContext { + in = append(in, reflect.ValueOf(traceCtx)) + } + if f.field.ArgsPacker != nil { + in = append(in, f.field.PackedArgs) + } + callOut := res.Method(f.field.MethodIndex).Call(in) + result = callOut[0] + if f.field.HasError && !callOut[1].IsNil() { + resolverErr := callOut[1].Interface().(error) + err := errors.Errorf("%s", resolverErr) + err.Path = path.toSlice() + err.ResolverError = resolverErr + if ex, ok := callOut[1].Interface().(extensionser); ok { + err.Extensions = ex.Extensions() + } + return err + } + } else { + // TODO extract out unwrapping ptr logic to a common place + if res.Kind() == reflect.Ptr { + res = res.Elem() } - return err + result = res.Field(f.field.FieldIndex) } return nil }() diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go index c4802520..27809230 100644 --- a/internal/exec/resolvable/resolvable.go +++ b/internal/exec/resolvable/resolvable.go @@ -33,6 +33,7 @@ type Field struct { schema.Field TypeName string MethodIndex int + FieldIndex int HasContext bool HasError bool ArgsPacker *packer.StructPacker @@ -40,6 +41,10 @@ type Field struct { TraceLabel string } +func (f *Field) UseMethodResolver() bool { + return f.FieldIndex == -1 +} + type TypeAssertion struct { MethodIndex int TypeExec Resolvable @@ -189,13 +194,13 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er implementsType := false switch r := reflect.New(resolverType).Interface().(type) { case *int32: - implementsType = (t.Name == "Int") + implementsType = t.Name == "Int" case *float64: - implementsType = (t.Name == "Float") + implementsType = t.Name == "Float" case *string: - implementsType = (t.Name == "String") + implementsType = t.Name == "String" case *bool: - implementsType = (t.Name == "Boolean") + implementsType = t.Name == "Boolean" case packer.Unmarshaler: implementsType = r.ImplementsGraphQLType(t.Name) } @@ -205,7 +210,8 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er return &Scalar{}, nil } -func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, nonNull bool, resolverType reflect.Type) (*Object, error) { +func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, + nonNull bool, resolverType reflect.Type) (*Object, error) { if !nonNull { if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface { return nil, fmt.Errorf("%s is not a pointer or interface", resolverType) @@ -215,9 +221,14 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p methodHasReceiver := resolverType.Kind() != reflect.Interface Fields := make(map[string]*Field) + rt := unwrapPtr(resolverType) for _, f := range fields { + fieldIndex := -1 methodIndex := findMethod(resolverType, f.Name) - if methodIndex == -1 { + if b.schema.UseFieldResolvers && methodIndex == -1 { + fieldIndex = findField(rt, f.Name) + } + if methodIndex == -1 && fieldIndex == -1 { hint := "" if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 { hint = " (hint: the method exists on the pointer type)" @@ -225,30 +236,41 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint) } - m := resolverType.Method(methodIndex) - fe, err := b.makeFieldExec(typeName, f, m, methodIndex, methodHasReceiver) + var m reflect.Method + var sf reflect.StructField + if methodIndex != -1 { + m = resolverType.Method(methodIndex) + } else { + sf = rt.Field(fieldIndex) + } + fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver) if err != nil { return nil, fmt.Errorf("%s\n\treturned by (%s).%s", err, resolverType, m.Name) } Fields[f.Name] = fe } + // Check type assertions when + // 1) using method resolvers + // 2) Or resolver is not an interface type typeAssertions := make(map[string]*TypeAssertion) - for _, impl := range possibleTypes { - methodIndex := findMethod(resolverType, "To"+impl.Name) - if methodIndex == -1 { - return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name) - } - if resolverType.Method(methodIndex).Type.NumOut() != 2 { - return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name) - } - a := &TypeAssertion{ - MethodIndex: methodIndex, - } - if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil { - return nil, err + if !b.schema.UseFieldResolvers || resolverType.Kind() != reflect.Interface { + for _, impl := range possibleTypes { + methodIndex := findMethod(resolverType, "To"+impl.Name) + if methodIndex == -1 { + return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name) + } + if resolverType.Method(methodIndex).Type.NumOut() != 2 { + return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name) + } + a := &TypeAssertion{ + MethodIndex: methodIndex, + } + if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil { + return nil, err + } + typeAssertions[impl.Name] = a } - typeAssertions[impl.Name] = a } return &Object{ @@ -261,50 +283,58 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p var contextType = reflect.TypeOf((*context.Context)(nil)).Elem() var errorType = reflect.TypeOf((*error)(nil)).Elem() -func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, methodIndex int, methodHasReceiver bool) (*Field, error) { - in := make([]reflect.Type, m.Type.NumIn()) - for i := range in { - in[i] = m.Type.In(i) - } - if methodHasReceiver { - in = in[1:] // first parameter is receiver - } - - hasContext := len(in) > 0 && in[0] == contextType - if hasContext { - in = in[1:] - } +func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField, + methodIndex, fieldIndex int, methodHasReceiver bool) (*Field, error) { var argsPacker *packer.StructPacker - if len(f.Args) > 0 { - if len(in) == 0 { - return nil, fmt.Errorf("must have parameter for field arguments") + var hasError bool + var hasContext bool + + // Validate resolver method only when there is one + if methodIndex != -1 { + in := make([]reflect.Type, m.Type.NumIn()) + for i := range in { + in[i] = m.Type.In(i) } - var err error - argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0]) - if err != nil { - return nil, err + if methodHasReceiver { + in = in[1:] // first parameter is receiver } - in = in[1:] - } - if len(in) > 0 { - return nil, fmt.Errorf("too many parameters") - } + hasContext = len(in) > 0 && in[0] == contextType + if hasContext { + in = in[1:] + } - maxNumOfReturns := 2 - if m.Type.NumOut() < maxNumOfReturns-1 { - return nil, fmt.Errorf("too few return values") - } + if len(f.Args) > 0 { + if len(in) == 0 { + return nil, fmt.Errorf("must have parameter for field arguments") + } + var err error + argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0]) + if err != nil { + return nil, err + } + in = in[1:] + } - if m.Type.NumOut() > maxNumOfReturns { - return nil, fmt.Errorf("too many return values") - } + if len(in) > 0 { + return nil, fmt.Errorf("too many parameters") + } - hasError := m.Type.NumOut() == maxNumOfReturns - if hasError { - if m.Type.Out(maxNumOfReturns-1) != errorType { - return nil, fmt.Errorf(`must have "error" as its last return value`) + maxNumOfReturns := 2 + if m.Type.NumOut() < maxNumOfReturns-1 { + return nil, fmt.Errorf("too few return values") + } + + if m.Type.NumOut() > maxNumOfReturns { + return nil, fmt.Errorf("too many return values") + } + + hasError = m.Type.NumOut() == maxNumOfReturns + if hasError { + if m.Type.Out(maxNumOfReturns-1) != errorType { + return nil, fmt.Errorf(`must have "error" as its last return value`) + } } } @@ -312,19 +342,26 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect. Field: *f, TypeName: typeName, MethodIndex: methodIndex, + FieldIndex: fieldIndex, HasContext: hasContext, ArgsPacker: argsPacker, HasError: hasError, TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name), } - out := m.Type.Out(0) - if typeName == "Subscription" && out.Kind() == reflect.Chan { - out = m.Type.Out(0).Elem() + var out reflect.Type + if methodIndex != -1 { + out = m.Type.Out(0) + if typeName == "Subscription" && out.Kind() == reflect.Chan { + out = m.Type.Out(0).Elem() + } + } else { + out = sf.Type } if err := b.assignExec(&fe.ValueExec, f.Type, out); err != nil { return nil, err } + return fe, nil } @@ -337,6 +374,15 @@ func findMethod(t reflect.Type, name string) int { return -1 } +func findField(t reflect.Type, name string) int { + for i := 0; i < t.NumField(); i++ { + if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Field(i).Name)) { + return i + } + } + return -1 +} + func unwrapNonNull(t common.Type) (common.Type, bool) { if nn, ok := t.(*common.NonNull); ok { return nn.OfType, true @@ -347,3 +393,10 @@ func unwrapNonNull(t common.Type) (common.Type, bool) { func stripUnderscore(s string) string { return strings.Replace(s, "_", "", -1) } + +func unwrapPtr(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + return t.Elem() + } + return t +} diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 569b26b2..08cc47e3 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -41,6 +41,8 @@ type Schema struct { // http://facebook.github.io/graphql/draft/#sec-Type-System.Directives Directives map[string]*DirectiveDecl + UseFieldResolvers bool + entryPointNames map[string]string objects []*Object unions []*Union