From 3fdb7f2a79bcf3195c036498cd815d8a4925d751 Mon Sep 17 00:00:00 2001 From: Rangel Reale Date: Wed, 8 Feb 2023 13:36:14 -0300 Subject: [PATCH] add replace-type parameter --- cmd/mockery.go | 1 + pkg/config/config.go | 3 +- pkg/fixtures/example_project/baz/foo.go | 12 +++ .../example_project/baz/internal/foo/foo.go | 6 ++ pkg/generator.go | 55 +++++++++++- pkg/generator_test.go | 84 +++++++++++++++++++ 6 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 pkg/fixtures/example_project/baz/foo.go create mode 100644 pkg/fixtures/example_project/baz/internal/foo/foo.go diff --git a/cmd/mockery.go b/cmd/mockery.go index 76e0d77f..dbc0c618 100644 --- a/cmd/mockery.go +++ b/cmd/mockery.go @@ -75,6 +75,7 @@ func NewRootCmd() *cobra.Command { pFlags.Bool("unroll-variadic", true, "For functions with variadic arguments, do not unroll the arguments into the underlying testify call. Instead, pass variadic slice as-is.") pFlags.Bool("exported", false, "Generates public mocks for private interfaces.") pFlags.Bool("with-expecter", false, "Generate expecter utility around mock's On, Run and Return methods with explicit types. This option is NOT compatible with -unroll-variadic=false") + pFlags.StringArray("replace-type", nil, "Replace types") viper.BindPFlags(pFlags) diff --git a/pkg/config/config.go b/pkg/config/config.go index b0ca2115..db92bf26 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -51,5 +51,6 @@ type Config struct { TestOnly bool UnrollVariadic bool `mapstructure:"unroll-variadic"` Version bool - WithExpecter bool `mapstructure:"with-expecter"` + WithExpecter bool `mapstructure:"with-expecter"` + ReplaceType []string `mapstructure:"replace-type"` } diff --git a/pkg/fixtures/example_project/baz/foo.go b/pkg/fixtures/example_project/baz/foo.go new file mode 100644 index 00000000..608b8f06 --- /dev/null +++ b/pkg/fixtures/example_project/baz/foo.go @@ -0,0 +1,12 @@ +package baz + +import ( + ifoo "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo" +) + +type Baz = ifoo.InternalBaz + +type Foo interface { + DoFoo() string + GetBaz() (*Baz, error) +} diff --git a/pkg/fixtures/example_project/baz/internal/foo/foo.go b/pkg/fixtures/example_project/baz/internal/foo/foo.go new file mode 100644 index 00000000..cb6930fe --- /dev/null +++ b/pkg/fixtures/example_project/baz/internal/foo/foo.go @@ -0,0 +1,6 @@ +package foo + +type InternalBaz struct { + One string + Two int +} diff --git a/pkg/generator.go b/pkg/generator.go index 69fc94b2..d0a84772 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -101,14 +101,67 @@ func (g *Generator) getPackageScopedType(ctx context.Context, o *types.TypeName) if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) { return o.Name() } - return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name() + pkg := g.addPackageImport(ctx, o.Pkg()) + name := o.Name() + g.checkReplaceType(ctx, func(from replaceType, to replaceType) bool { + if o.Pkg().Path() == from.pkg && name == from.typ { + name = to.typ + return false + } + return true + }) + return pkg + "." + name } func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) string { return g.addPackageImportWithName(ctx, pkg.Path(), pkg.Name()) } +type replaceType struct { + alias string + pkg string + typ string +} + +func parseReplaceType(t string) replaceType { + ret := replaceType{} + r := strings.SplitN(t, ":", 2) + if len(r) > 1 { + ret.alias = r[0] + t = r[1] + } + lastInd := strings.LastIndex(t, ".") + ret.pkg = t[:lastInd] + ret.typ = t[lastInd+1:] + return ret +} + +func (g *Generator) checkReplaceType(ctx context.Context, f func(from replaceType, to replaceType) bool) { + for _, replace := range g.ReplaceType { + r := strings.SplitN(replace, "=", 2) + if len(r) == 2 { + if !f(parseReplaceType(r[0]), parseReplaceType(r[1])) { + break + } + } else { + log := zerolog.Ctx(ctx) + log.Error().Msgf("invalid replace type value: %s", replace) + } + } +} + func (g *Generator) addPackageImportWithName(ctx context.Context, path, name string) string { + g.checkReplaceType(ctx, func(from replaceType, to replaceType) bool { + if path == from.pkg { + path = to.pkg + if to.alias != "" { + name = to.alias + } + return false + } + return true + }) + if existingName, pathExists := g.packagePathToName[path]; pathExists { return existingName } diff --git a/pkg/generator_test.go b/pkg/generator_test.go index efe92eac..6963338e 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -2329,6 +2329,90 @@ import mock "github.com/stretchr/testify/mock" s.checkPrologueGeneration(generator, expected) } +func (s *GeneratorSuite) TestInternalPackagePrologue() { + expected := `package mocks + +import baz "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz" +import mock "github.com/stretchr/testify/mock" + +` + generator := NewGenerator( + s.ctx, + config.Config{InPackage: false, LogLevel: "debug", ReplaceType: []string{ + "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz", + }}, + s.getInterfaceFromFile("example_project/baz/foo.go", "Foo"), + pkg, + ) + + s.checkPrologueGeneration(generator, expected) +} + +func (s *GeneratorSuite) TestInternalPackage() { + expected := `// Foo is an autogenerated mock type for the Foo type +type Foo struct { + mock.Mock +} + +// DoFoo provides a mock function with given fields: +func (_m *Foo) DoFoo() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// GetBaz provides a mock function with given fields: +func (_m *Foo) GetBaz() (*baz.Baz, error) { + ret := _m.Called() + + var r0 *baz.Baz + if rf, ok := ret.Get(0).(func() *baz.Baz); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*baz.Baz) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type mockConstructorTestingTNewFoo interface { + mock.TestingT + Cleanup(func()) +} + +// NewFoo creates a new instance of Foo. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewFoo(t mockConstructorTestingTNewFoo) *Foo { + mock := &Foo{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} +` + cfg := config.Config{InPackage: false, LogLevel: "debug", ReplaceType: []string{ + "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz", + }} + + s.checkGenerationWithConfig("example_project/baz/foo.go", "Foo", cfg, expected) +} + func (s *GeneratorSuite) TestGenericGenerator() { expected := `// RequesterGenerics is an autogenerated mock type for the RequesterGenerics type type RequesterGenerics[TAny interface{}, TComparable comparable, TSigned constraints.Signed, TIntf test.GetInt, TExternalIntf io.Writer, TGenIntf test.GetGeneric[TSigned], TInlineType interface{ ~int | ~uint }, TInlineTypeGeneric interface {