diff --git a/internal/template/template.go b/internal/template/template.go index 22bb8e89..9d7cd066 100644 --- a/internal/template/template.go +++ b/internal/template/template.go @@ -48,7 +48,22 @@ import ( {{- if not $.SkipEnsure -}} // Ensure, that {{.MockName}} does implement {{$.SrcPkgQualifier}}{{.InterfaceName}}. // If this is not the case, regenerate this file with moq. -var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{} +var _ {{$.SrcPkgQualifier}}{{.InterfaceName -}} + {{- if .TypeParams }}[ + {{- range $index, $param := .TypeParams}} + {{- if $index}}, {{end -}} + {{if $param.Constraint}}{{$param.Constraint.String}}{{else}}{{$param.TypeString}}{{end}} + {{- end -}} + ] + {{- end }} = &{{.MockName}} + {{- if .TypeParams }}[ + {{- range $index, $param := .TypeParams}} + {{- if $index}}, {{end -}} + {{if $param.Constraint}}{{$param.Constraint.String}}{{else}}{{$param.TypeString}}{{end}} + {{- end -}} + ] + {{- end -}} +{} {{- end}} // {{.MockName}} is a mock implementation of {{$.SrcPkgQualifier}}{{.InterfaceName}}. @@ -68,7 +83,12 @@ var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{} // // and then make assertions. // // } -type {{.MockName}} struct { +type {{.MockName}} +{{- if .TypeParams -}} + [{{- range $index, $param := .TypeParams}} + {{- if $index}}, {{end}}{{$param.Name | Exported}} {{$param.TypeString}} + {{- end -}}] +{{- end }} struct { {{- range .Methods}} // {{.Name}}Func mocks the {{.Name}} method. {{.Name}}Func func({{.ArgList}}) {{.ReturnArgTypeList}} @@ -91,7 +111,13 @@ type {{.MockName}} struct { } {{range .Methods}} // {{.Name}} calls {{.Name}}Func. -func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} { +func (mock *{{$mock.MockName}} +{{- if $mock.TypeParams -}} + [{{- range $index, $param := $mock.TypeParams}} + {{- if $index}}, {{end}}{{$param.Name | Exported}} + {{- end -}}] +{{- end -}} +) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} { {{- if not $.StubImpl}} if mock.{{.Name}}Func == nil { panic("{{$mock.MockName}}.{{.Name}}Func: method is nil but {{$mock.InterfaceName}}.{{.Name}} was just called") @@ -134,7 +160,13 @@ func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} { // {{.Name}}Calls gets all the calls that were made to {{.Name}}. // Check the length with: // len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls()) -func (mock *{{$mock.MockName}}) {{.Name}}Calls() []struct { +func (mock *{{$mock.MockName}} +{{- if $mock.TypeParams -}} + [{{- range $index, $param := $mock.TypeParams}} + {{- if $index}}, {{end}}{{$param.Name | Exported}} + {{- end -}}] +{{- end -}} +) {{.Name}}Calls() []struct { {{- range .Params}} {{.Name | Exported}} {{.TypeString}} {{- end}} diff --git a/internal/template/template_data.go b/internal/template/template_data.go index 0a95bb5e..12c91174 100644 --- a/internal/template/template_data.go +++ b/internal/template/template_data.go @@ -2,6 +2,7 @@ package template import ( "fmt" + "go/types" "strings" "github.com/matryer/moq/internal/registry" @@ -33,6 +34,7 @@ func (d Data) MocksSomeMethod() bool { type MockData struct { InterfaceName string MockName string + TypeParams []TypeParamData Methods []MethodData } @@ -87,6 +89,11 @@ func (m MethodData) ReturnArgNameList() string { return strings.Join(params, ", ") } +type TypeParamData struct { + ParamData + Constraint types.Type +} + // ParamData is the data which represents a parameter to some method of // an interface. type ParamData struct {