From b9f0105775f141a44ad7f416699a020458b04b0f Mon Sep 17 00:00:00 2001 From: Gevrai Jodoin-Tremblay Date: Sun, 11 Jul 2021 13:41:38 -0400 Subject: [PATCH] Add mock generation with expecter --- cmd/mockery.go | 1 + go.sum | 2 + pkg/config/config.go | 3 +- pkg/fixtures/expecterTest.go | 9 ++ pkg/fixtures/mocks/expecter.go | 228 ++++++++++++++++++++++++++++ pkg/fixtures/mocks/expecter_test.go | 194 +++++++++++++++++++++++ pkg/generator.go | 156 +++++++++++++++++-- pkg/generator_test.go | 145 +++++++++++++++--- 8 files changed, 709 insertions(+), 29 deletions(-) create mode 100644 pkg/fixtures/expecterTest.go create mode 100755 pkg/fixtures/mocks/expecter.go create mode 100644 pkg/fixtures/mocks/expecter_test.go diff --git a/cmd/mockery.go b/cmd/mockery.go index 4e3c82632..143dca759 100644 --- a/cmd/mockery.go +++ b/cmd/mockery.go @@ -93,6 +93,7 @@ func init() { pFlags.String("boilerplate-file", "", "File to read a boilerplate text from. Text should be a go block comment, i.e. /* ... */") pFlags.Bool("unroll-variadic", true, "For functions with variadic arguments, do not unroll the arguments into the underlying testify call. Instead, pass variadic slice as-is.") pFlags.Bool("exported", false, "Generates public mocks for private interfaces.") + pFlags.Bool("with-expecter", false, "Generate expecter utility around mock's On, Run and Return methods with explicit types. This option is NOT compatible with -unroll-variadic=false") viper.BindPFlags(pFlags) } diff --git a/go.sum b/go.sum index 369441b6a..372108342 100644 --- a/go.sum +++ b/go.sum @@ -193,6 +193,8 @@ github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= +github.com/vektra/mockery/v2 v2.9.0 h1:+3FhCL3EviR779mTzXwUuhPNnqFUA7sDnt9OFkXaFd4= +github.com/vektra/mockery/v2 v2.9.0/go.mod h1:2gU4Cf/f8YyC8oEaSXfCnZBMxMjMl/Ko205rlP0fO90= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/pkg/config/config.go b/pkg/config/config.go index 335f06bdc..0518ae83d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,7 +4,7 @@ package config var SemVer = "v0.0.0-dev" func GetSemverInfo() string { - return SemVer + return SemVer } type Config struct { @@ -39,4 +39,5 @@ type Config struct { TestOnly bool UnrollVariadic bool `mapstructure:"unroll-variadic"` Version bool + WithExpecter bool `mapstructure:"with-expecter"` } diff --git a/pkg/fixtures/expecterTest.go b/pkg/fixtures/expecterTest.go new file mode 100644 index 000000000..d9027c5f0 --- /dev/null +++ b/pkg/fixtures/expecterTest.go @@ -0,0 +1,9 @@ +package test + +type ExpecterTest interface { + NoArg() string + NoReturn(str string) + ManyArgsReturns(str string, i int) (strs []string, err error) + Variadic(ints ...int) error + VariadicMany(i int, a string, intfs ...interface{}) error +} diff --git a/pkg/fixtures/mocks/expecter.go b/pkg/fixtures/mocks/expecter.go new file mode 100755 index 000000000..7d44e4a17 --- /dev/null +++ b/pkg/fixtures/mocks/expecter.go @@ -0,0 +1,228 @@ +// Code generated by mockery v1.0.0. DO NOT EDIT. + +package mocks + +import "github.com/stretchr/testify/mock" + +// ExpecterTest is an autogenerated mock type for the ExpecterTest type +type ExpecterTest struct { + mock.Mock +} + +type ExpecterTest_Expecter struct { + mock *mock.Mock +} + +func (_m *ExpecterTest) EXPECT() *ExpecterTest_Expecter { + return &ExpecterTest_Expecter{mock: &_m.Mock} +} + +// ManyArgsReturns provides a mock function with given fields: str, i +func (_m *ExpecterTest) ManyArgsReturns(str string, i int) ([]string, error) { + ret := _m.Called(str, i) + + var r0 []string + if rf, ok := ret.Get(0).(func(string, int) []string); ok { + r0 = rf(str, i) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, int) error); ok { + r1 = rf(str, i) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ExpecterTest_ManyArgsReturns_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ManyArgsReturns' +type ExpecterTest_ManyArgsReturns_Call struct { + *mock.Call +} + +// ManyArgsReturns is a helper method to define mock.On call +// - str string +// - i int +func (_e *ExpecterTest_Expecter) ManyArgsReturns(str interface{}, i interface{}) *ExpecterTest_ManyArgsReturns_Call { + return &ExpecterTest_ManyArgsReturns_Call{Call: _e.mock.On("ManyArgsReturns", str, i)} +} + +func (_c *ExpecterTest_ManyArgsReturns_Call) Run(run func(str string, i int)) *ExpecterTest_ManyArgsReturns_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(int)) + }) + return _c +} + +func (_c *ExpecterTest_ManyArgsReturns_Call) Return(strs []string, err error) *ExpecterTest_ManyArgsReturns_Call { + _c.Call.Return(strs, err) + return _c +} + +// NoArg provides a mock function with given fields: +func (_m *ExpecterTest) NoArg() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// ExpecterTest_NoArg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NoArg' +type ExpecterTest_NoArg_Call struct { + *mock.Call +} + +// NoArg is a helper method to define mock.On call +func (_e *ExpecterTest_Expecter) NoArg() *ExpecterTest_NoArg_Call { + return &ExpecterTest_NoArg_Call{Call: _e.mock.On("NoArg")} +} + +func (_c *ExpecterTest_NoArg_Call) Run(run func()) *ExpecterTest_NoArg_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *ExpecterTest_NoArg_Call) Return(_a0 string) *ExpecterTest_NoArg_Call { + _c.Call.Return(_a0) + return _c +} + +// NoReturn provides a mock function with given fields: str +func (_m *ExpecterTest) NoReturn(str string) { + _m.Called(str) +} + +// ExpecterTest_NoReturn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NoReturn' +type ExpecterTest_NoReturn_Call struct { + *mock.Call +} + +// NoReturn is a helper method to define mock.On call +// - str string +func (_e *ExpecterTest_Expecter) NoReturn(str interface{}) *ExpecterTest_NoReturn_Call { + return &ExpecterTest_NoReturn_Call{Call: _e.mock.On("NoReturn", str)} +} + +func (_c *ExpecterTest_NoReturn_Call) Run(run func(str string)) *ExpecterTest_NoReturn_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *ExpecterTest_NoReturn_Call) Return() *ExpecterTest_NoReturn_Call { + _c.Call.Return() + return _c +} + +// Variadic provides a mock function with given fields: ints +func (_m *ExpecterTest) Variadic(ints ...int) error { + _va := make([]interface{}, len(ints)) + for _i := range ints { + _va[_i] = ints[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(...int) error); ok { + r0 = rf(ints...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ExpecterTest_Variadic_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Variadic' +type ExpecterTest_Variadic_Call struct { + *mock.Call +} + +// Variadic is a helper method to define mock.On call +// - ints ...int +func (_e *ExpecterTest_Expecter) Variadic(ints ...interface{}) *ExpecterTest_Variadic_Call { + return &ExpecterTest_Variadic_Call{Call: _e.mock.On("Variadic", + append([]interface{}{}, ints...)...)} +} + +func (_c *ExpecterTest_Variadic_Call) Run(run func(ints ...int)) *ExpecterTest_Variadic_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]int, len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(int) + } + } + run(variadicArgs...) + }) + return _c +} + +func (_c *ExpecterTest_Variadic_Call) Return(_a0 error) *ExpecterTest_Variadic_Call { + _c.Call.Return(_a0) + return _c +} + +// VariadicMany provides a mock function with given fields: i, a, intfs +func (_m *ExpecterTest) VariadicMany(i int, a string, intfs ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, i, a) + _ca = append(_ca, intfs...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(int, string, ...interface{}) error); ok { + r0 = rf(i, a, intfs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ExpecterTest_VariadicMany_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VariadicMany' +type ExpecterTest_VariadicMany_Call struct { + *mock.Call +} + +// VariadicMany is a helper method to define mock.On call +// - i int +// - a string +// - intfs ...interface{} +func (_e *ExpecterTest_Expecter) VariadicMany(i interface{}, a interface{}, intfs ...interface{}) *ExpecterTest_VariadicMany_Call { + return &ExpecterTest_VariadicMany_Call{Call: _e.mock.On("VariadicMany", + append([]interface{}{i, a}, intfs...)...)} +} + +func (_c *ExpecterTest_VariadicMany_Call) Run(run func(i int, a string, intfs ...interface{})) *ExpecterTest_VariadicMany_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(int), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *ExpecterTest_VariadicMany_Call) Return(_a0 error) *ExpecterTest_VariadicMany_Call { + _c.Call.Return(_a0) + return _c +} diff --git a/pkg/fixtures/mocks/expecter_test.go b/pkg/fixtures/mocks/expecter_test.go new file mode 100644 index 000000000..1e192641e --- /dev/null +++ b/pkg/fixtures/mocks/expecter_test.go @@ -0,0 +1,194 @@ +package mocks + +import ( + "errors" + "reflect" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var ( + defaultString = "some input string" + defaultInt = 1 + defaultError = errors.New("some error") +) + +// Test that the generated code for ExpecterTest interface is usable +func TestExpecter(t *testing.T) { + expMock := &ExpecterTest{} + + t.Run("NoArg", func(t *testing.T) { + var runCalled bool + expMock.EXPECT().NoArg().Run(func() { + runCalled = true + }).Return(defaultString).Once() + + // Good call + str := expMock.NoArg() + require.Equal(t, defaultString, str) + require.True(t, runCalled) + + // Call again panic + require.Panics(t, func() { + expMock.NoArg() + }) + expMock.AssertExpectations(t) + }) + + t.Run("NoReturn", func(t *testing.T) { + var runCalled bool + expMock.EXPECT().NoReturn(mock.Anything).Run(func(s string) { + require.Equal(t, defaultString, s) + runCalled = true + }).Return().Once() + + // Good call + expMock.NoReturn(defaultString) + require.True(t, runCalled) + + // Call again panic + require.Panics(t, func() { + expMock.NoReturn(defaultString) + }) + expMock.AssertExpectations(t) + }) + + t.Run("ManyArgsReturns", func(t *testing.T) { + var runCalled bool + expMock.EXPECT().ManyArgsReturns(mock.Anything, defaultInt).Run(func(s string, i int) { + require.Equal(t, defaultString, s) + require.Equal(t, defaultInt, i) + runCalled = true + }).Return([]string{defaultString, defaultString}, defaultError).Once() + + // Call with wrong arg + require.Panics(t, func() { + _, _ = expMock.ManyArgsReturns(defaultString, 0) + }) + + // Good call + strs, err := expMock.ManyArgsReturns(defaultString, defaultInt) + require.Equal(t, []string{defaultString, defaultString}, strs) + require.Equal(t, defaultError, err) + require.True(t, runCalled) + + // Call again panic + require.Panics(t, func() { + _, _ = expMock.ManyArgsReturns(defaultString, defaultInt) + }) + expMock.AssertExpectations(t) + }) + + t.Run("Variadic", func(t *testing.T) { + runCalled := 0 + + expMock.EXPECT().Variadic(1).Run(func(ints ...int) { + require.Equal(t, []int{1}, ints) + runCalled++ + }).Return(defaultError).Once() + + expMock.EXPECT().Variadic(1, 2, 3).Run(func(ints ...int) { + require.Equal(t, []int{1, 2, 3}, ints) + runCalled++ + }).Return(nil).Once() + + expMock.EXPECT().Variadic(1, mock.Anything, 3, mock.Anything).Run(func(ints ...int) { + require.Equal(t, []int{1, 2, 3, 4}, ints) + runCalled++ + }).Return(nil).Once() + + expMock.EXPECT().Variadic([]interface{}{2, 3, mock.Anything}...).Run(func(ints ...int) { + require.Equal(t, []int{2, 3, 4}, ints) + runCalled++ + }).Return(nil).Once() + + args := []int{1, 2, 3, 4, 5} + expMock.EXPECT().Variadic(intfSlice(args)...).Run(func(ints ...int) { + require.Equal(t, args, ints) + runCalled++ + }).Return(nil).Once() + + require.Error(t, expMock.Variadic(1)) + require.NoError(t, expMock.Variadic(1, 2, 3)) + require.NoError(t, expMock.Variadic(1, 2, 3, 4)) + require.NoError(t, expMock.Variadic(2, 3, 4)) + require.NoError(t, expMock.Variadic(args...)) + require.Equal(t, runCalled, 5) + expMock.AssertExpectations(t) + }) + + t.Run("VariadicOtherArgs", func(t *testing.T) { + runCalled := 0 + + expMock.EXPECT().VariadicMany(defaultInt, defaultString).Return(defaultError). + Run(func(i int, a string, intfs ...interface{}) { + require.Equal(t, defaultInt, i) + require.Equal(t, defaultString, a) + require.Empty(t, intfs) + runCalled++ + }).Once() + require.Error(t, expMock.VariadicMany(defaultInt, defaultString)) + + expMock.EXPECT().VariadicMany(defaultInt, defaultString, 1).Return(defaultError). + Run(func(i int, a string, intfs ...interface{}) { + require.Equal(t, defaultInt, i) + require.Equal(t, defaultString, a) + require.Equal(t, []interface{}{1}, intfs) + runCalled++ + }).Once() + require.Error(t, expMock.VariadicMany(defaultInt, defaultString, 1)) + + expMock.EXPECT().VariadicMany(mock.Anything, mock.Anything, 1, nil, mock.AnythingOfType("string")).Return(nil). + Run(func(i int, a string, intfs ...interface{}) { + require.Equal(t, defaultInt, i) + require.Equal(t, defaultString, a) + require.Equal(t, []interface{}{1, nil, "blah"}, intfs) + runCalled++ + }).Once() + require.Panics(t, func() { + expMock.VariadicMany(defaultInt, defaultString, 1, nil, 123) + }) + require.NoError(t, expMock.VariadicMany(defaultInt, defaultString, 1, nil, "blah")) + + expMock.EXPECT().VariadicMany(mock.Anything, mock.Anything, 1, nil, "blah").Run(func(i int, a string, intfs ...interface{}) { + require.Equal(t, defaultInt, i) + require.Equal(t, defaultString, a) + require.Equal(t, []interface{}{1, nil, "blah"}, intfs) + runCalled++ + }).Return(defaultError).Once() + require.Panics(t, func() { + expMock.VariadicMany(defaultInt, defaultString, 1, nil, "other string") + }) + err := expMock.VariadicMany(defaultInt, defaultString, 1, nil, "blah") + require.Equal(t, defaultError, err) + + args := []interface{}{1, 2, 3, 4, 5} + expMock.EXPECT().VariadicMany(defaultInt, defaultString, args...).Run(func(i int, a string, intfs ...interface{}) { + require.Equal(t, defaultInt, i) + require.Equal(t, defaultString, a) + require.Equal(t, []interface{}{1, 2, 3, 4, 5}, intfs) + runCalled++ + }).Return(nil).Once() + require.NoError(t, expMock.VariadicMany(defaultInt, defaultString, args...)) + + require.Equal(t, runCalled, 5) + expMock.AssertExpectations(t) + }) + +} + +func intfSlice(slice interface{}) []interface{} { + val := reflect.ValueOf(slice) + switch val.Kind() { + case reflect.Slice, reflect.Array, reflect.String: + out := make([]interface{}, val.Len()) + for i := 0; i < val.Len(); i++ { + out[i] = val.Index(i).Interface() + } + return out + default: + panic("inftSlice only accepts slices or arrays") + } +} diff --git a/pkg/generator.go b/pkg/generator.go index 6f8a59e2d..0f44eeb56 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -14,13 +14,13 @@ import ( "regexp" "sort" "strings" + "text/template" "unicode" "github.com/rs/zerolog" - "golang.org/x/tools/imports" - "github.com/vektra/mockery/v2/pkg/config" "github.com/vektra/mockery/v2/pkg/logging" + "golang.org/x/tools/imports" ) var invalidIdentifierChar = regexp.MustCompile("[^[:digit:][:alpha:]_]") @@ -201,7 +201,7 @@ func upperFirstOnly(s string) string { return unicode.ToUpper(r) } return r - },s) + }, s) } func (g *Generator) mockName() string { @@ -223,6 +223,10 @@ func (g *Generator) mockName() string { return g.iface.Name } +func (g *Generator) expecterName() string { + return g.mockName() + "_Expecter" +} + func (g *Generator) sortedImportNames() (importNames []string) { for name := range g.nameToPackagePath { importNames = append(importNames, name) @@ -300,6 +304,22 @@ func (g *Generator) printf(s string, vals ...interface{}) { fmt.Fprintf(&g.buf, s, vals...) } +var templates = template.New("base template") + +func (g *Generator) printTemplate(data interface{}, templateString string) { + err := templates.ExecuteTemplate(&g.buf, templateString, data) + if err != nil { + tmpl, err := templates.New(templateString).Parse(templateString) + if err != nil { + // couldn't compile template + panic(err) + } + if err := tmpl.Execute(&g.buf, data); err != nil { + panic(err) + } + } +} + type namer interface { Name() string } @@ -404,11 +424,12 @@ func isNillable(typ types.Type) bool { } type paramList struct { - Names []string - Types []string - Params []string - Nilable []bool - Variadic bool + Names []string + Types []string + Params []string + ParamsIntf []string + Nilable []bool + Variadic bool } func (g *Generator) genList(ctx context.Context, list *types.Tuple, variadic bool) *paramList { @@ -445,6 +466,12 @@ func (g *Generator) genList(ctx context.Context, list *types.Tuple, variadic boo params.Params = append(params.Params, fmt.Sprintf("%s %s", pname, ts)) params.Nilable = append(params.Nilable, isNillable(v.Type())) + + if strings.Contains(ts, "...") { + params.ParamsIntf = append(params.ParamsIntf, fmt.Sprintf("%s ...interface{}", pname)) + } else { + params.ParamsIntf = append(params.ParamsIntf, fmt.Sprintf("%s interface{}", pname)) + } } return ¶ms @@ -469,16 +496,25 @@ func (g *Generator) Generate(ctx context.Context) error { } g.printf( - "// %s is an autogenerated mock type for the %s type\n", g.mockName(), - g.iface.Name, + "// %s is an autogenerated mock type for the %s type\n", + g.mockName(), g.iface.Name, ) g.printf( "type %s struct {\n\tmock.Mock\n}\n\n", g.mockName(), ) + if g.WithExpecter { + g.generateExpecterStruct() + } + for _, method := range g.iface.Methods() { + // It's probably possible, but not worth the trouble for prototype + if method.Signature.Variadic() && g.WithExpecter && !g.UnrollVariadic { + return fmt.Errorf("cannot generate a valid expecter for variadic method with unroll-variadic=false") + } + ftype := method.Signature fname := method.Name @@ -558,11 +594,111 @@ func (g *Generator) Generate(ctx context.Context) error { } g.printf("}\n") + + // Construct expecter helper functions + if g.WithExpecter { + g.generateExpecterMethodCall(method, params, returns) + } } return nil } +func (g *Generator) generateExpecterStruct() { + data := struct{ MockName, ExpecterName string }{ + MockName: g.mockName(), + ExpecterName: g.expecterName(), + } + g.printTemplate(data, ` +type {{.ExpecterName}} struct { + mock *mock.Mock +} + +func (_m *{{.MockName}}) EXPECT() *{{.ExpecterName}} { + return &{{.ExpecterName}}{mock: &_m.Mock} +} +`) +} + +func (g *Generator) generateExpecterMethodCall(method *Method, params, returns *paramList) { + + data := struct { + MockName, ExpecterName string + CallStruct string + MethodName string + Params, ParamsAsInterfaces, Returns *paramList + LastParamName string + LastParamType string + NbNonVariadic int + }{ + MockName: g.mockName(), + ExpecterName: g.expecterName(), + CallStruct: fmt.Sprintf("%s_%s_Call", g.mockName(), method.Name), + MethodName: method.Name, + Params: params, + Returns: returns, + } + + // Get some info about parameters for variadic methods, way easier than doing it in golang template directly + if data.Params.Variadic { + data.LastParamName = data.Params.Names[len(data.Params.Names)-1] + data.LastParamType = strings.TrimLeft(data.Params.Types[len(data.Params.Types)-1], "...") + data.NbNonVariadic = len(data.Params.Types) - 1 + } + + g.printTemplate(data, ` +// {{.CallStruct}} is a *mock.Call that shadows Run/Return methods with type explicit version for method '{{.MethodName}}' +type {{.CallStruct}} struct { + *mock.Call +} + +// {{.MethodName}} is a helper method to define mock.On call +{{- range .Params.Params}} +// - {{.}} +{{- end}} +func (_e *{{.ExpecterName}}) {{.MethodName}}({{range .Params.ParamsIntf}}{{.}},{{end}}) *{{.CallStruct}} { + return &{{.CallStruct}}{Call: _e.mock.On("{{.MethodName}}", + {{- if not .Params.Variadic }} + {{- range .Params.Names}}{{.}},{{end}} + {{- else }} + append([]interface{}{ + {{- range $i, $name := .Params.Names }} + {{- if (lt $i $.NbNonVariadic)}} {{$name}}, + {{- else}} }, {{$name}}... + {{- end}} + {{- end}} )... + {{- end }} )} +} + +func (_c *{{.CallStruct}}) Run(run func({{range .Params.Params}}{{.}},{{end}})) *{{.CallStruct}} { + _c.Call.Run(func(args mock.Arguments) { + {{- if not .Params.Variadic }} + run({{range $i, $type := .Params.Types }}args[{{$i}}].({{$type}}),{{end}}) + {{- else}} + variadicArgs := make([]{{.LastParamType}}, len(args) - {{.NbNonVariadic}}) + for i, a := range args[{{.NbNonVariadic}}:] { + if a != nil { + variadicArgs[i] = a.({{.LastParamType}}) + } + } + run( + {{- range $i, $type := .Params.Types }} + {{- if (lt $i $.NbNonVariadic)}}args[{{$i}}].({{$type}}), + {{- else}}variadicArgs...) + {{- end}} + {{- end}} + {{- end}} + }) + return _c +} + +func (_c *{{.CallStruct}}) Return({{range .Returns.Params}}{{.}},{{end}}) *{{.CallStruct}} { + _c.Call.Return({{range .Returns.Names}}{{.}},{{end}}) + return _c +} +`) +} + // generateCalled returns the Mock.Called invocation string and, if necessary, prints the // steps to prepare its argument list. // diff --git a/pkg/generator_test.go b/pkg/generator_test.go index b9ce9c1c0..610373cfd 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -32,7 +32,9 @@ func (s *GeneratorSuite) SetupTest() { s.ctx = context.Background() } -func (s *GeneratorSuite) getInterfaceFromFile(interfacePath, interfaceName string) *Interface { +func (s *GeneratorSuite) getInterfaceFromFile( + interfacePath, interfaceName string, +) *Interface { if !strings.Contains(interfacePath, fixturePath) { interfacePath = filepath.Join(fixturePath, interfacePath) } @@ -50,23 +52,21 @@ func (s *GeneratorSuite) getInterfaceFromFile(interfacePath, interfaceName strin return iface } -func (s *GeneratorSuite) getGenerator( - filepath, interfaceName string, inPackage bool, structName string, +func (s *GeneratorSuite) getGeneratorWithConfig( + filepath, interfaceName string, cfg config.Config, ) *Generator { - return NewGenerator( - s.ctx, config.Config{ - StructName: structName, - InPackage: inPackage, - UnrollVariadic: true, - }, s.getInterfaceFromFile(filepath, interfaceName), pkg, - ) + return NewGenerator(s.ctx, cfg, s.getInterfaceFromFile(filepath, interfaceName), pkg) } -func (s *GeneratorSuite) checkGeneration( - filepath, interfaceName string, inPackage bool, structName string, expected string, +func (s *GeneratorSuite) checkGenerationWithConfig( + filepath, interfaceName string, cfg config.Config, expected string, ) *Generator { - generator := s.getGenerator(filepath, interfaceName, inPackage, structName) - s.NoError(generator.Generate(s.ctx), "The generator ran without errors.") + generator := s.getGeneratorWithConfig(filepath, interfaceName, cfg) + err := generator.Generate(s.ctx) + s.NoError(err, "The generator ran without errors.") + if err != nil { + return generator + } // Mirror the formatting done by normally done by golang.org/x/tools/imports in Generator.Write. // @@ -82,13 +82,32 @@ func (s *GeneratorSuite) checkGeneration( expectedLines := strings.Split(expected, "\n") actualLines := strings.Split(string(actual), "\n") - s.Equal( - expectedLines, actualLines, - "The generator produced unexpected output.", - ) + // Error out at first unmatched line + for i := range actualLines { + s.Equal(expectedLines[i], actualLines[i]) + } return generator } +func (s *GeneratorSuite) getGenerator( + filepath, interfaceName string, inPackage bool, structName string, +) *Generator { + return s.getGeneratorWithConfig(filepath, interfaceName, config.Config{ + StructName: structName, + InPackage: inPackage, + UnrollVariadic: true, + }) +} + +func (s *GeneratorSuite) checkGeneration(filepath, interfaceName string, inPackage bool, structName string, expected string) *Generator { + cfg := config.Config{ + StructName: structName, + InPackage: inPackage, + UnrollVariadic: true, + } + return s.checkGenerationWithConfig(filepath, interfaceName, cfg, expected) +} + func (s *GeneratorSuite) checkPrologueGeneration( generator *Generator, expected string, ) { @@ -137,6 +156,96 @@ func (_m *Requester) Get(path string) (string, error) { s.checkGeneration(testFile, "Requester", false, "", expected) } +func (s *GeneratorSuite) TestGeneratorRequesterWithExpecter() { + expected := `// Requester is an autogenerated mock type for the Requester type +type Requester struct { + mock.Mock +} + +type Requester_Expecter struct { + mock *mock.Mock +} + +func (_m *Requester) EXPECT() *Requester_Expecter { + return &Requester_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function with given fields: path +func (_m *Requester) Get(path string) (string, error) { + ret := _m.Called(path) + + var r0 string + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(path) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(path) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Requester_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type Requester_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - path string +func (_e *Requester_Expecter) Get(path interface{}) *Requester_Get_Call { + return &Requester_Get_Call{Call: _e.mock.On("Get", path)} +} + +func (_c *Requester_Get_Call) Run(run func(path string)) *Requester_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *Requester_Get_Call) Return(_a0 string, _a1 error) *Requester_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} +` + + cfg := config.Config{ + WithExpecter: true, + UnrollVariadic: false, // it's okay if the interface doesn't have any variadic method + } + s.checkGenerationWithConfig(testFile, "Requester", cfg, expected) +} + +func (s *GeneratorSuite) TestGeneratorExpecterComplete() { + expectedBytes, err := ioutil.ReadFile(filepath.Join(fixturePath, "mocks", "expecter.go")) + s.NoError(err) + expected := string(expectedBytes) + expected = expected[strings.Index(expected, "// ExpecterTest is"):] + + cfg := config.Config{ + WithExpecter: true, + UnrollVariadic: true, + } + s.checkGenerationWithConfig(testFile, "ExpecterTest", cfg, expected) +} + +func (s *GeneratorSuite) TestGeneratorExpecterFailsWithoutUnrolledVariadic() { + cfg := config.Config{ + WithExpecter: true, + UnrollVariadic: false, + } + gen := s.getGeneratorWithConfig(testFile, "ExpecterTest", cfg) + err := gen.Generate(s.ctx) + s.Error(err) + s.Contains(err.Error(), "cannot generate a valid expecter for variadic method with unroll-variadic=false") +} + func (s *GeneratorSuite) TestGeneratorFunction() { expected := `// SendFunc is an autogenerated mock type for the SendFunc type type SendFunc struct {