Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically add type conversions around wrapped types #13

Merged
merged 1 commit into from
Feb 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions codegen/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*

b := &Build{
PackageName: filepath.Base(destDir),
Objects: buildObjects(namedTypes, schema, prog),
Objects: buildObjects(namedTypes, schema, prog, imports),
Interfaces: buildInterfaces(namedTypes, schema),
Inputs: buildInputs(namedTypes, schema, prog),
Inputs: buildInputs(namedTypes, schema, prog, imports),
Imports: imports,
}

Expand All @@ -56,19 +56,19 @@ func Bind(schema *schema.Schema, userTypes map[string]string, destDir string) (*
// Poke a few magic methods into query
q := b.Objects.ByName(b.QueryRoot.GQLType)
q.Fields = append(q.Fields, Field{
Type: &Type{namedTypes["__Schema"], []string{modPtr}},
Type: &Type{namedTypes["__Schema"], []string{modPtr}, ""},
GQLName: "__schema",
NoErr: true,
GoMethodName: "ec.introspectSchema",
Object: q,
})
q.Fields = append(q.Fields, Field{
Type: &Type{namedTypes["__Type"], []string{modPtr}},
Type: &Type{namedTypes["__Type"], []string{modPtr}, ""},
GQLName: "__type",
NoErr: true,
GoMethodName: "ec.introspectType",
Args: []FieldArgument{
{GQLName: "name", Type: &Type{namedTypes["String"], []string{}}},
{GQLName: "name", Type: &Type{namedTypes["String"], []string{}, ""}},
},
Object: q,
})
Expand Down
4 changes: 2 additions & 2 deletions codegen/input_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"golang.org/x/tools/go/loader"
)

func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects {
var inputs Objects

for _, typ := range s.Types {
Expand All @@ -25,7 +25,7 @@ func buildInputs(namedTypes NamedTypes, s *schema.Schema, prog *loader.Program)
}
if def != nil {
input.Marshaler = buildInputMarshaler(typ, def)
bindObject(def.Type(), input)
bindObject(def.Type(), input, imports)
}

inputs = append(inputs, input)
Expand Down
9 changes: 0 additions & 9 deletions codegen/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,6 @@ type FieldArgument struct {

type Objects []*Object

func (o *Object) GetField(name string) *Field {
for i, field := range o.Fields {
if strings.EqualFold(field.GQLName, name) {
return &o.Fields[i]
}
}
return nil
}

func (o *Object) Implementors() string {
satisfiedBy := strconv.Quote(o.GQLType)
for _, s := range o.Satisfies {
Expand Down
4 changes: 2 additions & 2 deletions codegen/object_build.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"golang.org/x/tools/go/loader"
)

func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Objects {
func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program, imports Imports) Objects {
var objects Objects

for _, typ := range s.Types {
Expand All @@ -23,7 +23,7 @@ func buildObjects(types NamedTypes, s *schema.Schema, prog *loader.Program) Obje
fmt.Fprintf(os.Stderr, err.Error())
}
if def != nil {
bindObject(def.Type(), obj)
bindObject(def.Type(), obj, imports)
}

objects = append(objects, obj)
Expand Down
30 changes: 27 additions & 3 deletions codegen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Type struct {
*NamedType

Modifiers []string
CastType string // the type to cast to when unmarshalling
}

const (
Expand All @@ -46,6 +47,15 @@ func (t Type) Signature() string {
return strings.Join(t.Modifiers, "") + t.FullName()
}

func (t Type) FullSignature() string {
pkg := ""
if t.Package != "" {
pkg = t.Package + "."
}

return strings.Join(t.Modifiers, "") + pkg + t.GoType
}

func (t Type) IsPtr() bool {
return len(t.Modifiers) > 0 && t.Modifiers[0] == modPtr
}
Expand All @@ -59,18 +69,32 @@ func (t NamedType) IsMarshaled() bool {
}

func (t Type) Unmarshal(result, raw string) string {
if t.Marshaler != nil {
return result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")"
realResult := result
if t.CastType != "" {
result = "castTmp"
}
return tpl(`var {{.result}} {{.type}}
ret := tpl(`var {{.result}} {{.type}}
err := (&{{.result}}).UnmarshalGQL({{.raw}})`, map[string]interface{}{
"result": result,
"raw": raw,
"type": t.FullName(),
})

if t.Marshaler != nil {
ret = result + ", err := " + t.Marshaler.pkgDot() + "Unmarshal" + t.Marshaler.GoType + "(" + raw + ")"
}

if t.CastType != "" {
ret += "\n" + realResult + " := " + t.CastType + "(castTmp)"
}
return ret
}

func (t Type) Marshal(result, val string) string {
if t.CastType != "" {
val = t.GoType + "(" + val + ")"
}

if t.Marshaler != nil {
return result + " = " + t.Marshaler.pkgDot() + "Marshal" + t.Marshaler.GoType + "(" + val + ")"
}
Expand Down
129 changes: 82 additions & 47 deletions codegen/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,69 +47,104 @@ func isMethod(t types.Object) bool {
return f.Type().(*types.Signature).Recv() != nil
}

func bindObject(t types.Type, object *Object) bool {
switch t := t.(type) {
case *types.Named:
for i := 0; i < t.NumMethods(); i++ {
method := t.Method(i)
if !method.Exported() {
continue
}
func findMethod(typ *types.Named, name string) *types.Func {
for i := 0; i < typ.NumMethods(); i++ {
method := typ.Method(i)
if !method.Exported() {
continue
}

if strings.EqualFold(method.Name(), name) {
return method
}
}
return nil
}

func findField(typ *types.Struct, name string) *types.Var {
for i := 0; i < typ.NumFields(); i++ {
field := typ.Field(i)
if !field.Exported() {
continue
}

if methodField := object.GetField(method.Name()); methodField != nil {
methodField.GoMethodName = "it." + method.Name()
sig := method.Type().(*types.Signature)
if strings.EqualFold(field.Name(), name) {
return field
}
}
return nil
}

methodField.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())
func bindObject(t types.Type, object *Object, imports Imports) {
namedType, ok := t.(*types.Named)
if !ok {
fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String())
return
}

// check arg order matches code, not gql
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
fmt.Fprintf(os.Stderr, "expected %s to be a named struct, instead found %s", object.FullName(), t.String())
return
}

var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range methodField.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
for i := range object.Fields {
field := &object.Fields[i]
if method := findMethod(namedType, field.GQLName); method != nil {
sig := method.Type().(*types.Signature)
field.GoMethodName = "it." + method.Name()
field.Type.Modifiers = modifiersFromGoType(sig.Results().At(0).Type())

// check arg order matches code, not gql
var newArgs []FieldArgument
l2:
for j := 0; j < sig.Params().Len(); j++ {
param := sig.Params().At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.GQLName, param.Name()) {
oldArg.Type.Modifiers = modifiersFromGoType(param.Type())
newArgs = append(newArgs, oldArg)
continue l2
}
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
methodField.Args = newArgs
fmt.Fprintln(os.Stderr, "cannot match argument "+param.Name()+" to any argument in "+t.String())
}
field.Args = newArgs

if sig.Results().Len() == 1 {
methodField.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
if sig.Results().Len() == 1 {
field.NoErr = true
} else if sig.Results().Len() != 2 {
fmt.Fprintf(os.Stderr, "weird number of results on %s. expected either (result), or (result, error)\n", method.Name())
}
continue
}

bindObject(t.Underlying(), object)
return true
if structField := findField(underlying, field.GQLName); structField != nil {
field.Type.Modifiers = modifiersFromGoType(structField.Type())
field.GoVarName = "it." + structField.Name()

case *types.Struct:
for i := 0; i < t.NumFields(); i++ {
field := t.Field(i)
// Todo: struct tags, name and - at least
switch field.Type.FullSignature() {
case structField.Type().String():
// everything is fine

if !field.Exported() {
continue
}
case structField.Type().Underlying().String():
pkg, typ := pkgAndType(structField.Type().String())
imp := imports.findByPkg(pkg)
field.CastType = typ
if imp.Name != "" {
field.CastType = imp.Name + "." + typ
}

// Todo: check for type matches before binding too?
if objectField := object.GetField(field.Name()); objectField != nil {
objectField.GoVarName = "it." + field.Name()
objectField.Type.Modifiers = modifiersFromGoType(field.Type())
default:
fmt.Fprintf(os.Stderr, "type mismatch on %s.%s, expected %s got %s\n", object.GQLType, field.GQLName, field.Type.FullSignature(), structField.Type())
}
continue
}
t.Underlying()
return true
}

return false
if field.IsScalar {
fmt.Fprintf(os.Stderr, "unable to bind %s.%s to anything, %s has no suitable fields or methods\n", object.GQLType, field.GQLName, namedType.String())
}
}
}

func modifiersFromGoType(t types.Type) []string {
Expand Down
17 changes: 16 additions & 1 deletion example/scalars/generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ func (ec *executionContext) _user(sel []query.Selection, it *User) graphql.Marsh
res := it.Location

out.Values[i] = res
case "isBanned":
badArgs := false
if badArgs {
continue
}
res := it.IsBanned

out.Values[i] = graphql.MarshalBoolean(bool(res))
default:
panic("unknown field " + strconv.Quote(field.Name))
}
Expand Down Expand Up @@ -807,13 +815,20 @@ func UnmarshalSearchArgs(v interface{}) (SearchArgs, error) {
return it, err
}
it.CreatedAfter = &val
case "isBanned":
castTmp, err := graphql.UnmarshalBoolean(v)
val := Banned(castTmp)
if err != nil {
return it, err
}
it.IsBanned = val
}
}

return it, nil
}

var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n}\n\nscalar Timestamp\nscalar Point\n")
var parsedSchema = schema.MustParse("schema {\n query: Query\n}\n\ntype Query {\n user(id: ID!): User\n search(input: SearchArgs!): [User!]!\n}\n\ntype User {\n id: ID!\n name: String!\n created: Timestamp\n location: Point\n isBanned: Boolean!\n}\n\ninput SearchArgs {\n location: Point\n createdAfter: Timestamp\n isBanned: Boolean\n}\n\nscalar Timestamp\nscalar Point\n")

func (ec *executionContext) introspectSchema() *introspection.Schema {
return introspection.WrapSchema(parsedSchema)
Expand Down
4 changes: 4 additions & 0 deletions example/scalars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ import (
"github.com/vektah/gqlgen/graphql"
)

type Banned bool

type User struct {
ID string
Name string
Location Point // custom scalar types
Created time.Time // direct binding to builtin types with external Marshal/Unmarshal methods
IsBanned Banned // aliased primitive
}

// Point is serialized as a simple array, eg [1, 2]
Expand Down Expand Up @@ -71,4 +74,5 @@ func UnmarshalTimestamp(v interface{}) (time.Time, error) {
type SearchArgs struct {
Location *Point
CreatedAfter *time.Time
IsBanned Banned
}
2 changes: 2 additions & 0 deletions example/scalars/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ type User {
name: String!
created: Timestamp
location: Point
isBanned: Boolean!
}

input SearchArgs {
location: Point
createdAfter: Timestamp
isBanned: Boolean
}

scalar Timestamp
Expand Down
2 changes: 1 addition & 1 deletion example/starwars/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (h *Human) Height(unit string) float64 {
type Starship struct {
ID string
Name string
History [][2]int
History [][]int
lengthMeters float64
}

Expand Down
Loading