Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User defined custom types #10

Merged
merged 8 commits into from
Feb 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
type Build struct {
PackageName string
Objects Objects
Inputs Objects
Interfaces []*Interface
Imports Imports
QueryRoot *Object
Expand All @@ -29,10 +30,13 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
return nil, err
}

bindTypes(imports, namedTypes, prog)

b := &Build{
PackageName: filepath.Base(destDir),
Objects: buildObjects(namedTypes, schema, prog),
Interfaces: buildInterfaces(namedTypes, schema),
Inputs: buildInputs(namedTypes, schema, prog),
Imports: imports,
}

Expand Down
5 changes: 1 addition & 4 deletions codegen/import_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ func buildImports(types NamedTypes, destDir string) Imports {
{"io", "io"},
{"strconv", "strconv"},
{"time", "time"},
{"reflect", "reflect"},
{"strings", "strings"},
{"sync", "sync"},
{"mapstructure", "github.com/mitchellh/mapstructure"},
{"introspection", "github.com/vektah/gqlgen/neelance/introspection"},
{"errors", "github.com/vektah/gqlgen/neelance/errors"},
{"query", "github.com/vektah/gqlgen/neelance/query"},
{"schema", "github.com/vektah/gqlgen/neelance/schema"},
{"validation", "github.com/vektah/gqlgen/neelance/validation"},
{"jsonw", "github.com/vektah/gqlgen/jsonw"},
{"graphql", "github.com/vektah/gqlgen/graphql"},
}

for _, t := range types {
Expand Down
70 changes: 70 additions & 0 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package codegen

import (
"fmt"
"go/types"
"os"
"sort"
"strings"

"github.com/vektah/gqlgen/neelance/schema"
"golang.org/x/tools/go/loader"
)

func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
var inputs Objects

for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.InputObject:
input := buildInput(namedTypes, typ)

def, err := findGoType(prog, input.Package, input.GoType)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindObject(def.Type(), input)
}

inputs = append(inputs, input)
}
}

sort.Slice(inputs, func(i, j int) bool {
return strings.Compare(inputs[i].GQLType, inputs[j].GQLType) == -1
})

return inputs
}

func buildInput(types NamedTypes, typ *schema.InputObject) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

for _, field := range typ.Values {
obj.Fields = append(obj.Fields, Field{
GQLName: field.Name.Name,
Type: types.getType(field.Type),
Object: obj,
})
}
return obj
}

// if user has implemented an UnmarshalGQL method on the input type manually, use it
// otherwise we will generate one.
func buildInputMarshaler(typ *schema.InputObject, def types.Object) *Ref {
switch def := def.(type) {
case *types.TypeName:
namedType := def.Type().(*types.Named)
for i := 0; i < namedType.NumMethods(); i++ {
method := namedType.Method(i)
if method.Name() == "UnmarshalGQL" {
return nil
}
}
}

return &Ref{GoType: typ.Name}
}
28 changes: 15 additions & 13 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type FieldArgument struct {
GQLName string // The name of the argument in graphql
}

type Objects []*Object

func (o *Object) GetField(name string) *Field {
for i, field := range o.Fields {
if strings.EqualFold(field.GQLName, name) {
Expand Down Expand Up @@ -105,7 +107,7 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
case len(remainingMods) > 0 && remainingMods[0] == modPtr:
return tpl(`
if {{.val}} == nil {
{{.res}} = jsonw.Null
{{.res}} = graphql.Null
} else {
{{.next}}
}`, map[string]interface{}{
Expand All @@ -123,9 +125,9 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
var index = "idx" + strconv.Itoa(depth)

return tpl(`
{{.arr}} := jsonw.Array{}
{{.arr}} := graphql.Array{}
for {{.index}} := range {{.val}} {
var {{.tmp}} jsonw.Writer
var {{.tmp}} graphql.Marshaler
{{.next}}
{{.arr}} = append({{.arr}}, {{.tmp}})
}
Expand All @@ -142,7 +144,7 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
if isPtr {
val = "*" + val
}
return fmt.Sprintf("%s = jsonw.%s(%s)", res, ucFirst(f.GoType), val)
return f.Marshal(res, val)

default:
if !isPtr {
Expand All @@ -152,21 +154,21 @@ func (f *Field) doWriteJson(res string, val string, remainingMods []string, isPt
}
}

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.GQLType, name) {
return os[i]
}
}
return nil
}

func tpl(tpl string, vars map[string]interface{}) string {
b := &bytes.Buffer{}
template.Must(template.New("inline").Parse(tpl)).Execute(b, vars)
return b.String()
}

func ucFirst(s string) string {
if s == "" {
return ""
}
r := []rune(s)
r[0] = unicode.ToUpper(r[0])
return string(r)
}

func lcFirst(s string) string {
if s == "" {
return ""
Expand Down
146 changes: 8 additions & 138 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package codegen

import (
"fmt"
"go/types"
"os"
"sort"
"strings"
Expand All @@ -11,16 +10,21 @@ import (
"golang.org/x/tools/go/loader"
)

type Objects []*Object

func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
var objects Objects

for _, typ := range s.Types {
switch typ := typ.(type) {
case *schema.Object:
obj := buildObject(types, typ)
bindObject(prog, obj)

def, err := findGoType(prog, obj.Package, obj.GoType)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
if def != nil {
bindObject(def.Type(), obj)
}

objects = append(objects, obj)
}
Expand All @@ -41,15 +45,6 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
return objects
}

func (os Objects) ByName(name string) *Object {
for i, o := range os {
if strings.EqualFold(o.GQLType, name) {
return os[i]
}
}
return nil
}

func buildObject(types NamedTypes, typ *schema.Object) *Object {
obj := &Object{NamedType: types[typ.TypeName()]}

Expand All @@ -75,128 +70,3 @@ func buildObject(types NamedTypes, typ *schema.Object) *Object {
}
return obj
}

func bindObject(prog *loader.Program, obj *Object) {
if obj.Package == "" {
return
}
pkgName, err := resolvePkg(obj.Package)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to resolve package for %s: %s\n", obj.GQLType, err.Error())
return
}

pkg := prog.Imported[pkgName]
if pkg == nil {
fmt.Fprintf(os.Stderr, "required package was not loaded: %s", pkgName)
return
}

for astNode, object := range pkg.Defs {
if astNode.Name != obj.GoType {
continue
}

if findBindTargets(object.Type(), obj) {
return
}
}
}

func findBindTargets(t types.Type, object *Object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
method := t.Method(i)
if !method.Exported() {
continue
}

if methodField := object.GetField(method.Name()); methodField != nil {
methodField.GoMethodName = "it." + method.Name()
sig := method.Type().(*types.Signature)

methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql

var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range methodField.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
}
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
methodField.Args = newArgs

if sig.Results().Len() == 1 {
methodField.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
}
}

findBindTargets(t.Underlying(), object)
return true

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
// Todo: struct tags, name and - at least

if !field.Exported() {
continue
}

// Todo: check for type matches before binding too?
if objectField := object.GetField(field.Name()); objectField != nil {
objectField.GoVarName = "it." + field.Name()
objectField.Type.Modifiers = modifiersFromGoType(field.Type())
}
}
t.Underlying()
return true
}

return false
}

func mutationRoot(schema *schema.Schema) string {
if mu, ok := schema.EntryPoints["mutation"]; ok {
return mu.TypeName()
}
return ""
}

func queryRoot(schema *schema.Schema) string {
if mu, ok := schema.EntryPoints["mutation"]; ok {
return mu.TypeName()
}
return ""
}

func modifiersFromGoType(t types.Type) []string {
var modifiers []string
for {
switch val := t.(type) {
case *types.Pointer:
modifiers = append(modifiers, modPtr)
t = val.Elem()
case *types.Array:
modifiers = append(modifiers, modList)
t = val.Elem()
case *types.Slice:
modifiers = append(modifiers, modList)
t = val.Elem()
default:
return modifiers
}
}
}
Loading