diff --git a/README.md b/README.md
index 8d68d2c7..01038bfa 100644
--- a/README.md
+++ b/README.md
@@ -65,7 +65,17 @@ $ curl -XPOST -d '{"query": "{ hello }"}' localhost:8080/query
### Resolvers
-A resolver must have one method for each field of the GraphQL type it resolves. The method name has to be [exported](https://golang.org/ref/spec#Exported_identifiers) and match the field's name in a non-case-sensitive way.
+A resolver must have one method or field for each field of the GraphQL type it resolves. The method or field name has to be [exported](https://golang.org/ref/spec#Exported_identifiers) and match the schema's field's name in a non-case-sensitive way.
+You can use struct fields as resolvers by using `SchemaOpt: UseFieldResolvers()`. For example,
+```
+opts := []graphql.SchemaOpt{graphql.UseFieldResolvers()}
+schema := graphql.MustParseSchema(s, &query{}, opts...)
+```
+
+When using `UseFieldResolvers` schema option, a struct field will be used *only* when:
+- there is no method for a struct field
+- a struct field does not implement an interface method
+- a struct field does not have arguments
The method has up to two arguments:
diff --git a/example/social/README.md b/example/social/README.md
new file mode 100644
index 00000000..5ab316fd
--- /dev/null
+++ b/example/social/README.md
@@ -0,0 +1,9 @@
+### Social App
+
+A simple example of how to use struct fields as resolvers instead of methods.
+
+To run this server
+
+`go run ./example/field-resolvers/server/server.go`
+
+and go to localhost:9011 to interact
\ No newline at end of file
diff --git a/example/social/server/server.go b/example/social/server/server.go
new file mode 100644
index 00000000..6bfde72b
--- /dev/null
+++ b/example/social/server/server.go
@@ -0,0 +1,62 @@
+package main
+
+import (
+ "log"
+ "net/http"
+
+ "github.com/graph-gophers/graphql-go"
+ "github.com/graph-gophers/graphql-go/example/social"
+ "github.com/graph-gophers/graphql-go/relay"
+)
+
+func main() {
+ opts := []graphql.SchemaOpt{graphql.UseFieldResolvers(), graphql.MaxParallelism(20)}
+ schema := graphql.MustParseSchema(social.Schema, &social.Resolver{}, opts...)
+
+ http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write(page)
+ }))
+
+ http.Handle("/query", &relay.Handler{Schema: schema})
+
+ log.Fatal(http.ListenAndServe(":9011", nil))
+}
+
+var page = []byte(`
+
+
+
+
+
+
+
+
+
+
+
+
Loading...
+
+
+
+`)
diff --git a/example/social/social.go b/example/social/social.go
new file mode 100644
index 00000000..67774207
--- /dev/null
+++ b/example/social/social.go
@@ -0,0 +1,206 @@
+package social
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/graph-gophers/graphql-go"
+)
+
+const Schema = `
+ schema {
+ query: Query
+ }
+
+ type Query {
+ admin(id: ID!, role: Role = ADMIN): Admin!
+ user(id: ID!): User!
+ search(text: String!): [SearchResult]!
+ }
+
+ interface Admin {
+ id: ID!
+ name: String!
+ role: Role!
+ }
+
+ scalar Time
+
+ type User implements Admin {
+ id: ID!
+ name: String!
+ email: String!
+ role: Role!
+ phone: String!
+ address: [String!]
+ friends(page: Pagination): [User]
+ createdAt: Time!
+ }
+
+ input Pagination {
+ first: Int
+ last: Int
+ }
+
+ enum Role {
+ ADMIN
+ USER
+ }
+
+ union SearchResult = User
+`
+
+type page struct {
+ First *float64
+ Last *float64
+}
+
+type admin interface {
+ ID() graphql.ID
+ Name() string
+ Role() string
+}
+
+type searchResult struct {
+ result interface{}
+}
+
+func (r *searchResult) ToUser() (*user, bool) {
+ res, ok := r.result.(*user)
+ return res, ok
+}
+
+type user struct {
+ IDField string
+ NameField string
+ RoleField string
+ Email string
+ Phone string
+ Address *[]string
+ Friends *[]*user
+ CreatedAt graphql.Time
+}
+
+func (u user) ID() graphql.ID {
+ return graphql.ID(u.IDField)
+}
+
+func (u user) Name() string {
+ return u.NameField
+}
+
+func (u user) Role() string {
+ return u.RoleField
+}
+
+func (u user) FriendsResolver(args struct{ Page *page }) (*[]*user, error) {
+ var from int
+ numFriends := len(*u.Friends)
+ to := numFriends
+
+ if args.Page != nil {
+ if args.Page.First != nil {
+ from = int(*args.Page.First)
+ if from > numFriends {
+ return nil, errors.New("not enough users")
+ }
+ }
+ if args.Page.Last != nil {
+ to = int(*args.Page.Last)
+ if to == 0 || to > numFriends {
+ to = numFriends
+ }
+ }
+ }
+
+ friends := (*u.Friends)[from:to]
+
+ return &friends, nil
+}
+
+var users = []*user{
+ {
+ IDField: "0x01",
+ NameField: "Albus Dumbledore",
+ RoleField: "ADMIN",
+ Email: "Albus@hogwarts.com",
+ Phone: "000-000-0000",
+ Address: &[]string{"Office @ Hogwarts", "where Horcruxes are"},
+ CreatedAt: graphql.Time{Time: time.Now()},
+ },
+ {
+ IDField: "0x02",
+ NameField: "Harry Potter",
+ RoleField: "USER",
+ Email: "harry@hogwarts.com",
+ Phone: "000-000-0001",
+ Address: &[]string{"123 dorm room @ Hogwarts", "456 random place"},
+ CreatedAt: graphql.Time{Time: time.Now()},
+ },
+ {
+ IDField: "0x03",
+ NameField: "Hermione Granger",
+ RoleField: "USER",
+ Email: "hermione@hogwarts.com",
+ Phone: "000-000-0011",
+ Address: &[]string{"233 dorm room @ Hogwarts", "786 @ random place"},
+ CreatedAt: graphql.Time{Time: time.Now()},
+ },
+ {
+ IDField: "0x04",
+ NameField: "Ronald Weasley",
+ RoleField: "USER",
+ Email: "ronald@hogwarts.com",
+ Phone: "000-000-0111",
+ Address: &[]string{"411 dorm room @ Hogwarts", "981 @ random place"},
+ CreatedAt: graphql.Time{Time: time.Now()},
+ },
+}
+
+var usersMap = make(map[string]*user)
+
+func init() {
+ users[0].Friends = &[]*user{users[1]}
+ users[1].Friends = &[]*user{users[0], users[2], users[3]}
+ users[2].Friends = &[]*user{users[1], users[3]}
+ users[3].Friends = &[]*user{users[1], users[2]}
+ for _, usr := range users {
+ usersMap[usr.IDField] = usr
+ }
+}
+
+type Resolver struct{}
+
+func (r *Resolver) Admin(ctx context.Context, args struct {
+ ID string
+ Role string
+}) (admin, error) {
+ if usr, ok := usersMap[args.ID]; ok {
+ if usr.RoleField == args.Role {
+ return *usr, nil
+ }
+ }
+ err := fmt.Errorf("user with id=%s and role=%s does not exist", args.ID, args.Role)
+ return user{}, err
+}
+
+func (r *Resolver) User(ctx context.Context, args struct{ Id string }) (user, error) {
+ if usr, ok := usersMap[args.Id]; ok {
+ return *usr, nil
+ }
+ err := fmt.Errorf("user with id=%s does not exist", args.Id)
+ return user{}, err
+}
+
+func (r *Resolver) Search(ctx context.Context, args struct{ Text string }) ([]*searchResult, error) {
+ var result []*searchResult
+ for _, usr := range users {
+ if strings.Contains(usr.NameField, args.Text) {
+ result = append(result, &searchResult{usr})
+ }
+ }
+ return result, nil
+}
diff --git a/graphql.go b/graphql.go
index 35768a47..f3fe32eb 100644
--- a/graphql.go
+++ b/graphql.go
@@ -2,9 +2,8 @@ package graphql
import (
"context"
- "fmt"
-
"encoding/json"
+ "fmt"
"github.com/graph-gophers/graphql-go/errors"
"github.com/graph-gophers/graphql-go/internal/common"
@@ -84,6 +83,13 @@ func UseStringDescriptions() SchemaOpt {
}
}
+// UseFieldResolvers specifies whether to use struct field resolvers
+func UseFieldResolvers() SchemaOpt {
+ return func(s *Schema) {
+ s.schema.UseFieldResolvers = true
+ }
+}
+
// MaxDepth specifies the maximum field nesting depth in a query. The default is 0 which disables max depth checking.
func MaxDepth(n int) SchemaOpt {
return func(s *Schema) {
diff --git a/internal/exec/exec.go b/internal/exec/exec.go
index c326fc95..e878888f 100644
--- a/internal/exec/exec.go
+++ b/internal/exec/exec.go
@@ -178,24 +178,33 @@ func execFieldSelection(ctx context.Context, r *Request, f *fieldToExec, path *p
return errors.Errorf("%s", err) // don't execute any more resolvers if context got cancelled
}
- var in []reflect.Value
- if f.field.HasContext {
- in = append(in, reflect.ValueOf(traceCtx))
- }
- if f.field.ArgsPacker != nil {
- in = append(in, f.field.PackedArgs)
- }
- callOut := f.resolver.Method(f.field.MethodIndex).Call(in)
- result = callOut[0]
- if f.field.HasError && !callOut[1].IsNil() {
- resolverErr := callOut[1].Interface().(error)
- err := errors.Errorf("%s", resolverErr)
- err.Path = path.toSlice()
- err.ResolverError = resolverErr
- if ex, ok := callOut[1].Interface().(extensionser); ok {
- err.Extensions = ex.Extensions()
+ res := f.resolver
+ if f.field.UseMethodResolver() {
+ var in []reflect.Value
+ if f.field.HasContext {
+ in = append(in, reflect.ValueOf(traceCtx))
+ }
+ if f.field.ArgsPacker != nil {
+ in = append(in, f.field.PackedArgs)
+ }
+ callOut := res.Method(f.field.MethodIndex).Call(in)
+ result = callOut[0]
+ if f.field.HasError && !callOut[1].IsNil() {
+ resolverErr := callOut[1].Interface().(error)
+ err := errors.Errorf("%s", resolverErr)
+ err.Path = path.toSlice()
+ err.ResolverError = resolverErr
+ if ex, ok := callOut[1].Interface().(extensionser); ok {
+ err.Extensions = ex.Extensions()
+ }
+ return err
+ }
+ } else {
+ // TODO extract out unwrapping ptr logic to a common place
+ if res.Kind() == reflect.Ptr {
+ res = res.Elem()
}
- return err
+ result = res.Field(f.field.FieldIndex)
}
return nil
}()
diff --git a/internal/exec/resolvable/resolvable.go b/internal/exec/resolvable/resolvable.go
index c4802520..27809230 100644
--- a/internal/exec/resolvable/resolvable.go
+++ b/internal/exec/resolvable/resolvable.go
@@ -33,6 +33,7 @@ type Field struct {
schema.Field
TypeName string
MethodIndex int
+ FieldIndex int
HasContext bool
HasError bool
ArgsPacker *packer.StructPacker
@@ -40,6 +41,10 @@ type Field struct {
TraceLabel string
}
+func (f *Field) UseMethodResolver() bool {
+ return f.FieldIndex == -1
+}
+
type TypeAssertion struct {
MethodIndex int
TypeExec Resolvable
@@ -189,13 +194,13 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er
implementsType := false
switch r := reflect.New(resolverType).Interface().(type) {
case *int32:
- implementsType = (t.Name == "Int")
+ implementsType = t.Name == "Int"
case *float64:
- implementsType = (t.Name == "Float")
+ implementsType = t.Name == "Float"
case *string:
- implementsType = (t.Name == "String")
+ implementsType = t.Name == "String"
case *bool:
- implementsType = (t.Name == "Boolean")
+ implementsType = t.Name == "Boolean"
case packer.Unmarshaler:
implementsType = r.ImplementsGraphQLType(t.Name)
}
@@ -205,7 +210,8 @@ func makeScalarExec(t *schema.Scalar, resolverType reflect.Type) (Resolvable, er
return &Scalar{}, nil
}
-func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object, nonNull bool, resolverType reflect.Type) (*Object, error) {
+func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, possibleTypes []*schema.Object,
+ nonNull bool, resolverType reflect.Type) (*Object, error) {
if !nonNull {
if resolverType.Kind() != reflect.Ptr && resolverType.Kind() != reflect.Interface {
return nil, fmt.Errorf("%s is not a pointer or interface", resolverType)
@@ -215,9 +221,14 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p
methodHasReceiver := resolverType.Kind() != reflect.Interface
Fields := make(map[string]*Field)
+ rt := unwrapPtr(resolverType)
for _, f := range fields {
+ fieldIndex := -1
methodIndex := findMethod(resolverType, f.Name)
- if methodIndex == -1 {
+ if b.schema.UseFieldResolvers && methodIndex == -1 {
+ fieldIndex = findField(rt, f.Name)
+ }
+ if methodIndex == -1 && fieldIndex == -1 {
hint := ""
if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 {
hint = " (hint: the method exists on the pointer type)"
@@ -225,30 +236,41 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p
return nil, fmt.Errorf("%s does not resolve %q: missing method for field %q%s", resolverType, typeName, f.Name, hint)
}
- m := resolverType.Method(methodIndex)
- fe, err := b.makeFieldExec(typeName, f, m, methodIndex, methodHasReceiver)
+ var m reflect.Method
+ var sf reflect.StructField
+ if methodIndex != -1 {
+ m = resolverType.Method(methodIndex)
+ } else {
+ sf = rt.Field(fieldIndex)
+ }
+ fe, err := b.makeFieldExec(typeName, f, m, sf, methodIndex, fieldIndex, methodHasReceiver)
if err != nil {
return nil, fmt.Errorf("%s\n\treturned by (%s).%s", err, resolverType, m.Name)
}
Fields[f.Name] = fe
}
+ // Check type assertions when
+ // 1) using method resolvers
+ // 2) Or resolver is not an interface type
typeAssertions := make(map[string]*TypeAssertion)
- for _, impl := range possibleTypes {
- methodIndex := findMethod(resolverType, "To"+impl.Name)
- if methodIndex == -1 {
- return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name)
- }
- if resolverType.Method(methodIndex).Type.NumOut() != 2 {
- return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name)
- }
- a := &TypeAssertion{
- MethodIndex: methodIndex,
- }
- if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil {
- return nil, err
+ if !b.schema.UseFieldResolvers || resolverType.Kind() != reflect.Interface {
+ for _, impl := range possibleTypes {
+ methodIndex := findMethod(resolverType, "To"+impl.Name)
+ if methodIndex == -1 {
+ return nil, fmt.Errorf("%s does not resolve %q: missing method %q to convert to %q", resolverType, typeName, "To"+impl.Name, impl.Name)
+ }
+ if resolverType.Method(methodIndex).Type.NumOut() != 2 {
+ return nil, fmt.Errorf("%s does not resolve %q: method %q should return a value and a bool indicating success", resolverType, typeName, "To"+impl.Name)
+ }
+ a := &TypeAssertion{
+ MethodIndex: methodIndex,
+ }
+ if err := b.assignExec(&a.TypeExec, impl, resolverType.Method(methodIndex).Type.Out(0)); err != nil {
+ return nil, err
+ }
+ typeAssertions[impl.Name] = a
}
- typeAssertions[impl.Name] = a
}
return &Object{
@@ -261,50 +283,58 @@ func (b *execBuilder) makeObjectExec(typeName string, fields schema.FieldList, p
var contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
var errorType = reflect.TypeOf((*error)(nil)).Elem()
-func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, methodIndex int, methodHasReceiver bool) (*Field, error) {
- in := make([]reflect.Type, m.Type.NumIn())
- for i := range in {
- in[i] = m.Type.In(i)
- }
- if methodHasReceiver {
- in = in[1:] // first parameter is receiver
- }
-
- hasContext := len(in) > 0 && in[0] == contextType
- if hasContext {
- in = in[1:]
- }
+func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.Method, sf reflect.StructField,
+ methodIndex, fieldIndex int, methodHasReceiver bool) (*Field, error) {
var argsPacker *packer.StructPacker
- if len(f.Args) > 0 {
- if len(in) == 0 {
- return nil, fmt.Errorf("must have parameter for field arguments")
+ var hasError bool
+ var hasContext bool
+
+ // Validate resolver method only when there is one
+ if methodIndex != -1 {
+ in := make([]reflect.Type, m.Type.NumIn())
+ for i := range in {
+ in[i] = m.Type.In(i)
}
- var err error
- argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0])
- if err != nil {
- return nil, err
+ if methodHasReceiver {
+ in = in[1:] // first parameter is receiver
}
- in = in[1:]
- }
- if len(in) > 0 {
- return nil, fmt.Errorf("too many parameters")
- }
+ hasContext = len(in) > 0 && in[0] == contextType
+ if hasContext {
+ in = in[1:]
+ }
- maxNumOfReturns := 2
- if m.Type.NumOut() < maxNumOfReturns-1 {
- return nil, fmt.Errorf("too few return values")
- }
+ if len(f.Args) > 0 {
+ if len(in) == 0 {
+ return nil, fmt.Errorf("must have parameter for field arguments")
+ }
+ var err error
+ argsPacker, err = b.packerBuilder.MakeStructPacker(f.Args, in[0])
+ if err != nil {
+ return nil, err
+ }
+ in = in[1:]
+ }
- if m.Type.NumOut() > maxNumOfReturns {
- return nil, fmt.Errorf("too many return values")
- }
+ if len(in) > 0 {
+ return nil, fmt.Errorf("too many parameters")
+ }
- hasError := m.Type.NumOut() == maxNumOfReturns
- if hasError {
- if m.Type.Out(maxNumOfReturns-1) != errorType {
- return nil, fmt.Errorf(`must have "error" as its last return value`)
+ maxNumOfReturns := 2
+ if m.Type.NumOut() < maxNumOfReturns-1 {
+ return nil, fmt.Errorf("too few return values")
+ }
+
+ if m.Type.NumOut() > maxNumOfReturns {
+ return nil, fmt.Errorf("too many return values")
+ }
+
+ hasError = m.Type.NumOut() == maxNumOfReturns
+ if hasError {
+ if m.Type.Out(maxNumOfReturns-1) != errorType {
+ return nil, fmt.Errorf(`must have "error" as its last return value`)
+ }
}
}
@@ -312,19 +342,26 @@ func (b *execBuilder) makeFieldExec(typeName string, f *schema.Field, m reflect.
Field: *f,
TypeName: typeName,
MethodIndex: methodIndex,
+ FieldIndex: fieldIndex,
HasContext: hasContext,
ArgsPacker: argsPacker,
HasError: hasError,
TraceLabel: fmt.Sprintf("GraphQL field: %s.%s", typeName, f.Name),
}
- out := m.Type.Out(0)
- if typeName == "Subscription" && out.Kind() == reflect.Chan {
- out = m.Type.Out(0).Elem()
+ var out reflect.Type
+ if methodIndex != -1 {
+ out = m.Type.Out(0)
+ if typeName == "Subscription" && out.Kind() == reflect.Chan {
+ out = m.Type.Out(0).Elem()
+ }
+ } else {
+ out = sf.Type
}
if err := b.assignExec(&fe.ValueExec, f.Type, out); err != nil {
return nil, err
}
+
return fe, nil
}
@@ -337,6 +374,15 @@ func findMethod(t reflect.Type, name string) int {
return -1
}
+func findField(t reflect.Type, name string) int {
+ for i := 0; i < t.NumField(); i++ {
+ if strings.EqualFold(stripUnderscore(name), stripUnderscore(t.Field(i).Name)) {
+ return i
+ }
+ }
+ return -1
+}
+
func unwrapNonNull(t common.Type) (common.Type, bool) {
if nn, ok := t.(*common.NonNull); ok {
return nn.OfType, true
@@ -347,3 +393,10 @@ func unwrapNonNull(t common.Type) (common.Type, bool) {
func stripUnderscore(s string) string {
return strings.Replace(s, "_", "", -1)
}
+
+func unwrapPtr(t reflect.Type) reflect.Type {
+ if t.Kind() == reflect.Ptr {
+ return t.Elem()
+ }
+ return t
+}
diff --git a/internal/schema/schema.go b/internal/schema/schema.go
index 569b26b2..08cc47e3 100644
--- a/internal/schema/schema.go
+++ b/internal/schema/schema.go
@@ -41,6 +41,8 @@ type Schema struct {
// http://facebook.github.io/graphql/draft/#sec-Type-System.Directives
Directives map[string]*DirectiveDecl
+ UseFieldResolvers bool
+
entryPointNames map[string]string
objects []*Object
unions []*Union