From 6ec365046295574d8e903ee93cf24b7d49b4b180 Mon Sep 17 00:00:00 2001 From: Adam Date: Mon, 20 Jan 2020 16:05:40 +1100 Subject: [PATCH] Copy existing resolver bodies when regenerating new resolvers --- codegen/templates/templates.go | 14 ++--- example/config/.gqlgen.yml | 3 +- example/config/resolver.go | 53 ----------------- example/config/schema_resolvers.go | 30 ++++++++++ example/config/todo_resolvers.go | 22 +++++++ internal/rewrite/rewriter.go | 87 ++++++++++++++++++++++++++++ internal/rewrite/rewriter_test.go | 22 +++++++ internal/rewrite/testdata/example.go | 14 +++++ plugin/resolvergen/resolver.go | 55 ++++++++++-------- plugin/resolvergen/resolver.gotpl | 2 +- 10 files changed, 217 insertions(+), 85 deletions(-) create mode 100644 example/config/schema_resolvers.go create mode 100644 example/config/todo_resolvers.go create mode 100644 internal/rewrite/rewriter.go create mode 100644 internal/rewrite/rewriter_test.go create mode 100644 internal/rewrite/testdata/example.go diff --git a/codegen/templates/templates.go b/codegen/templates/templates.go index 8e126d6d1fb..e8722442531 100644 --- a/codegen/templates/templates.go +++ b/codegen/templates/templates.go @@ -162,8 +162,8 @@ func center(width int, pad string, s string) string { func Funcs() template.FuncMap { return template.FuncMap{ - "ucFirst": ucFirst, - "lcFirst": lcFirst, + "ucFirst": UcFirst, + "lcFirst": LcFirst, "quote": strconv.Quote, "rawQuote": rawQuote, "dump": Dump, @@ -185,7 +185,7 @@ func Funcs() template.FuncMap { } } -func ucFirst(s string) string { +func UcFirst(s string) string { if s == "" { return "" } @@ -194,7 +194,7 @@ func ucFirst(s string) string { return string(r) } -func lcFirst(s string) string { +func LcFirst(s string) string { if s == "" { return "" } @@ -275,7 +275,7 @@ func ToGo(name string) string { if strings.ToUpper(word) == word || strings.ToLower(word) == word { // FOO or foo → Foo // FOo → FOo - word = ucFirst(strings.ToLower(word)) + word = UcFirst(strings.ToLower(word)) } } runes = append(runes, []rune(word)...) @@ -297,13 +297,13 @@ func ToGoPrivate(name string) string { word = strings.ToLower(info.Word) } else { // ITicket → iTicket - word = lcFirst(info.Word) + word = LcFirst(info.Word) } first = false case info.MatchCommonInitial: word = strings.ToUpper(word) case !info.HasCommonInitial: - word = ucFirst(strings.ToLower(word)) + word = UcFirst(strings.ToLower(word)) } runes = append(runes, []rune(word)...) }) diff --git a/example/config/.gqlgen.yml b/example/config/.gqlgen.yml index ccea6688f39..c8a41dc0e40 100644 --- a/example/config/.gqlgen.yml +++ b/example/config/.gqlgen.yml @@ -9,8 +9,9 @@ exec: model: filename: models_gen.go resolver: - filename: resolver.go type: Resolver + layout: follow-schema + dir: . models: Todo: # Object diff --git a/example/config/resolver.go b/example/config/resolver.go index 519db1fe290..182634a1756 100644 --- a/example/config/resolver.go +++ b/example/config/resolver.go @@ -2,11 +2,6 @@ package config -import ( - "context" - "fmt" -) - func New() Config { c := Config{ Resolvers: &Resolver{ @@ -25,51 +20,3 @@ type Resolver struct { todos []*Todo nextID int } - -func (r *Resolver) Mutation() MutationResolver { - return &mutationResolver{r} -} -func (r *Resolver) Query() QueryResolver { - return &queryResolver{r} -} -func (r *Resolver) Todo() TodoResolver { - return &todoResolver{r} -} - -type mutationResolver struct{ *Resolver } - -func (r *mutationResolver) CreateTodo(ctx context.Context, input NewTodo) (*Todo, error) { - newID := r.nextID - r.nextID++ - - newTodo := &Todo{ - DatabaseID: newID, - Description: input.Text, - } - - r.todos = append(r.todos, newTodo) - - return newTodo, nil -} - -type queryResolver struct{ *Resolver } - -func (r *queryResolver) Todos(ctx context.Context) ([]*Todo, error) { - return r.todos, nil -} - -type todoResolver struct{ *Resolver } - -func (r *todoResolver) Description(ctx context.Context, obj *Todo) (string, error) { - panic("implement me") -} - -func (r *todoResolver) ID(ctx context.Context, obj *Todo) (string, error) { - if obj.ID != "" { - return obj.ID, nil - } - - obj.ID = fmt.Sprintf("TODO:%d", obj.DatabaseID) - - return obj.ID, nil -} diff --git a/example/config/schema_resolvers.go b/example/config/schema_resolvers.go new file mode 100644 index 00000000000..7c590b68c03 --- /dev/null +++ b/example/config/schema_resolvers.go @@ -0,0 +1,30 @@ +// This file will be automatically regenerated based on the schema, any resolver implementations +// will be copied through when generating and any unknown code will be moved to the end. +package config + +import ( + "context" +) + +func (r *mutationResolver) CreateTodo(ctx context.Context, input NewTodo) (*Todo, error) { + newID := r.nextID + r.nextID++ + + newTodo := &Todo{ + DatabaseID: newID, + Description: input.Text, + } + + r.todos = append(r.todos, newTodo) + + return newTodo, nil +} +func (r *queryResolver) Todos(ctx context.Context) ([]*Todo, error) { + return r.todos, nil +} + +func (r *Resolver) Mutation() MutationResolver { return &mutationResolver{r} } +func (r *Resolver) Query() QueryResolver { return &queryResolver{r} } + +type mutationResolver struct{ *Resolver } +type queryResolver struct{ *Resolver } diff --git a/example/config/todo_resolvers.go b/example/config/todo_resolvers.go new file mode 100644 index 00000000000..7bddb90d90c --- /dev/null +++ b/example/config/todo_resolvers.go @@ -0,0 +1,22 @@ +// This file will be automatically regenerated based on the schema, any resolver implementations +// will be copied through when generating and any unknown code will be moved to the end. +package config + +import ( + "context" + "fmt" +) + +func (r *todoResolver) ID(ctx context.Context, obj *Todo) (string, error) { + if obj.ID != "" { + return obj.ID, nil + } + + obj.ID = fmt.Sprintf("TODO:%d", obj.DatabaseID) + + return obj.ID, nil +} + +func (r *Resolver) Todo() TodoResolver { return &todoResolver{r} } + +type todoResolver struct{ *Resolver } diff --git a/internal/rewrite/rewriter.go b/internal/rewrite/rewriter.go new file mode 100644 index 00000000000..f569eb32501 --- /dev/null +++ b/internal/rewrite/rewriter.go @@ -0,0 +1,87 @@ +package rewrite + +import ( + "fmt" + "go/ast" + "go/token" + "io/ioutil" + + "golang.org/x/tools/go/packages" +) + +type Rewriter struct { + pkg *packages.Package + files map[string]string +} + +func New(importPath string) (*Rewriter, error) { + pkgs, err := packages.Load(&packages.Config{ + Mode: packages.NeedSyntax | packages.NeedTypes, + }, importPath) + if err != nil { + return nil, err + } + + return &Rewriter{ + pkg: pkgs[0], + files: map[string]string{}, + }, nil +} + +func (r *Rewriter) getSource(start, end token.Pos) string { + startPos := r.pkg.Fset.Position(start) + endPos := r.pkg.Fset.Position(end) + + if startPos.Filename != endPos.Filename { + panic("cant get source spanning multiple files") + } + + file := r.getFile(startPos.Filename) + return file[startPos.Offset:endPos.Offset] +} + +func (r *Rewriter) getFile(filename string) string { + if _, ok := r.files[filename]; !ok { + b, err := ioutil.ReadFile(filename) + if err != nil { + panic(fmt.Errorf("unable to load file, already exists: %s", err.Error())) + } + + r.files[filename] = string(b) + + } + + return r.files[filename] +} + +func (r *Rewriter) GetMethodBody(structname string, methodname string) string { + for _, f := range r.pkg.Syntax { + for _, d := range f.Decls { + switch d := d.(type) { + case *ast.FuncDecl: + if d.Name.Name != methodname { + continue + } + if d.Recv.List == nil { + continue + } + recv := d.Recv.List[0].Type + if star, isStar := d.Recv.List[0].Type.(*ast.StarExpr); isStar { + recv = star.X + } + ident, ok := recv.(*ast.Ident) + if !ok { + continue + } + + if ident.Name != structname { + continue + } + + return r.getSource(d.Body.Pos()+1, d.Body.End()-1) + } + } + } + + return "" +} diff --git a/internal/rewrite/rewriter_test.go b/internal/rewrite/rewriter_test.go new file mode 100644 index 00000000000..367461a2e6c --- /dev/null +++ b/internal/rewrite/rewriter_test.go @@ -0,0 +1,22 @@ +package rewrite + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRewriter(t *testing.T) { + r, err := New("github.com/99designs/gqlgen/internal/rewrite/testdata") + require.NoError(t, err) + + body := r.GetMethodBody("Foo", "Method") + require.Equal(t, ` + // leading comment + + // field comment + m.Field++ + + // trailing comment +`, body) +} diff --git a/internal/rewrite/testdata/example.go b/internal/rewrite/testdata/example.go new file mode 100644 index 00000000000..db82c66c50e --- /dev/null +++ b/internal/rewrite/testdata/example.go @@ -0,0 +1,14 @@ +package testdata + +type Foo struct { + Field int +} + +func (m *Foo) Method(arg int) { + // leading comment + + // field comment + m.Field++ + + // trailing comment +} diff --git a/plugin/resolvergen/resolver.go b/plugin/resolvergen/resolver.go index f5bef243652..4b936f3cb65 100644 --- a/plugin/resolvergen/resolver.go +++ b/plugin/resolvergen/resolver.go @@ -5,12 +5,12 @@ import ( "path/filepath" "strings" - "github.com/pkg/errors" - "github.com/99designs/gqlgen/codegen" "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/templates" + "github.com/99designs/gqlgen/internal/rewrite" "github.com/99designs/gqlgen/plugin" + "github.com/pkg/errors" ) func New() plugin.Plugin { @@ -57,7 +57,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error { continue } - resolver := Resolver{o, f} + resolver := Resolver{o, f, `panic("not implemented")`} file.Resolvers = append(file.Resolvers, &resolver) } } @@ -78,6 +78,11 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error { } func (m *Plugin) generatePerSchema(data *codegen.Data) error { + rewriter, err := rewrite.New(data.Config.Resolver.ImportPath()) + if err != nil { + return err + } + files := map[string]*File{} for _, o := range data.Objects { @@ -94,7 +99,13 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error { continue } - resolver := Resolver{o, f} + structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type) + implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName)) + if implementation == "" { + implementation = `panic(fmt.Errorf("not implemented"))` + } + + resolver := Resolver{o, f, implementation} fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name) if files[fn] == nil { files[fn] = &File{} @@ -124,23 +135,20 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error { } } - if data.Config.Resolver.Layout == config.LayoutFollowSchema { - rootFilename := filepath.Join(data.Config.Resolver.Dir(), "resolver.go") - - if _, err := os.Stat(rootFilename); os.IsNotExist(errors.Cause(err)) { - err := templates.Render(templates.Options{ - PackageName: data.Config.Resolver.Package, - PackageDoc: ` - // This file will not be regenerated automatically. - // - // It serves as dependency injection for your app, add any dependencies you require here.`, - Template: `type {{.}} struct {}`, - Filename: rootFilename, - Data: data.Config.Resolver.Type, - }) - if err != nil { - return err - } + rootFilename := filepath.Join(data.Config.Resolver.Dir(), "resolver.go") + if _, err := os.Stat(rootFilename); os.IsNotExist(errors.Cause(err)) { + err := templates.Render(templates.Options{ + PackageName: data.Config.Resolver.Package, + PackageDoc: ` + // This file will not be regenerated automatically. + // + // It serves as dependency injection for your app, add any dependencies you require here.`, + Template: `type {{.}} struct {}`, + Filename: rootFilename, + Data: data.Config.Resolver.Type, + }) + if err != nil { + return err } } return nil @@ -161,8 +169,9 @@ type File struct { } type Resolver struct { - Object *codegen.Object - Field *codegen.Field + Object *codegen.Object + Field *codegen.Field + Implementation string } func (r *Resolver) filename(cfg config.ResolverConfig) string { diff --git a/plugin/resolvergen/resolver.gotpl b/plugin/resolvergen/resolver.gotpl index af36a0525d0..eab32b878d3 100644 --- a/plugin/resolvergen/resolver.gotpl +++ b/plugin/resolvergen/resolver.gotpl @@ -18,7 +18,7 @@ {{ range $resolver := .Resolvers -}} func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ $resolver.Field.ShortResolverDeclaration }} { - panic("not implemented") + {{ $resolver.Implementation }} } {{ end }}