Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom matcher functions to return error strings displayed on a failed match #639

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,10 @@ const (
Anything = "mock.Anything"
)

var (
errorType = reflect.TypeOf((*error)(nil)).Elem()
)

// AnythingOfTypeArgument is a string that contains the type of an argument
// for use when type checking. Used in Diff and Assert.
type AnythingOfTypeArgument string
Expand All @@ -578,6 +582,10 @@ type argumentMatcher struct {
}

func (f argumentMatcher) Matches(argument interface{}) bool {
return f.match(argument) == nil
}

func (f argumentMatcher) match(argument interface{}) error {
expectType := f.fn.Type().In(0)
expectTypeNilSupported := false
switch expectType.Kind() {
Expand All @@ -598,25 +606,52 @@ func (f argumentMatcher) Matches(argument interface{}) bool {
}
if argType == nil || argType.AssignableTo(expectType) {
result := f.fn.Call([]reflect.Value{arg})
return result[0].Bool()

var matchError error
switch {
case result[0].Type().Kind() == reflect.Bool:
if !result[0].Bool() {
matchError = fmt.Errorf("not matched by %s", f)
}
case result[0].Type().Implements(errorType):
if !result[0].IsNil() {
matchError = result[0].Interface().(error)
}
default:
panic(fmt.Errorf("matcher function of unknown type: %s", result[0].Type().Kind()))
}

return matchError
}
return false
return fmt.Errorf("unexpected type for %s", f)
}

func (f argumentMatcher) String() string {
return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name())
return fmt.Sprintf("func(%s) %s", f.fn.Type().In(0).String(), f.fn.Type().Out(0).String())
}

func (f argumentMatcher) GoString() string {
return fmt.Sprintf("MatchedBy(%s)", f)
}

// MatchedBy can be used to match a mock call based on only certain properties
// from a complex struct or some calculation. It takes a function that will be
// evaluated with the called argument and will return true when there's a match
// and false otherwise.
// evaluated with the called argument and will return either a boolean (true
// when there's a match and false otherwise) or an error (nil when there's a
// match and error holding the failure message otherwise).
//
// Examples:
// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
//
// Example:
// m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
// m.On("Do", MatchedBy(func(req *http.Request) (err error) {
// if req.Host != "example.com" {
// err = errors.New("host was not example.com")
// }
// return
// })
//
// |fn|, must be a function accepting a single argument (of the expected type)
// which returns a bool. If |fn| doesn't match the required signature,
// which returns a bool or error. If |fn| doesn't match the required signature,
// MatchedBy() panics.
func MatchedBy(fn interface{}) argumentMatcher {
fnType := reflect.TypeOf(fn)
Expand All @@ -627,8 +662,9 @@ func MatchedBy(fn interface{}) argumentMatcher {
if fnType.NumIn() != 1 {
panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
}
if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))

if fnType.NumOut() != 1 || (fnType.Out(0).Kind() != reflect.Bool && !fnType.Out(0).Implements(errorType)) {
panic(fmt.Sprintf("assert: arguments: %s does not return a bool or a error", fn))
}

return argumentMatcher{fn: reflect.ValueOf(fn)}
Expand Down Expand Up @@ -688,11 +724,11 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
}

if matcher, ok := expected.(argumentMatcher); ok {
if matcher.Matches(actual) {
if matchError := matcher.match(actual); matchError == nil {
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else {
differences++
output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
output = fmt.Sprintf("%s\t%d: FAIL: %s %s\n", output, i, actualFmt, matchError)
}
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {

Expand Down
29 changes: 27 additions & 2 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ func Test_Arguments_Diff_WithArgMatcher(t *testing.T) {

diff, count = args.Diff([]interface{}{"string", false, true})
assert.Equal(t, 1, count)
assert.Contains(t, diff, `(bool=false) not matched by func(int) bool`)
assert.Contains(t, diff, `(bool=false) unexpected type for func(int) bool`)

diff, count = args.Diff([]interface{}{"string", 123, false})
assert.Contains(t, diff, `(int=123) matched by func(int) bool`)
Expand All @@ -1269,6 +1269,31 @@ func Test_Arguments_Diff_WithArgMatcher(t *testing.T) {
assert.Contains(t, diff, `No differences.`)
}

func Test_Arguments_Diff_WithArgMatcherReturningError(t *testing.T) {
matchFn := func(a int) (err error) {
if a != 123 {
err = errors.New("did not match")
}
return
}
var args = Arguments([]interface{}{"string", MatchedBy(matchFn), true})

diff, count := args.Diff([]interface{}{"string", 124, true})
assert.Equal(t, 1, count)
assert.Contains(t, diff, `(int=124) did not match`)

diff, count = args.Diff([]interface{}{"string", false, true})
assert.Equal(t, 1, count)
assert.Contains(t, diff, `(bool=false) unexpected type for func(int) error`)

diff, count = args.Diff([]interface{}{"string", 123, false})
assert.Contains(t, diff, `(int=123) matched by func(int) error`)

diff, count = args.Diff([]interface{}{"string", 123, true})
assert.Equal(t, 0, count)
assert.Contains(t, diff, `No differences.`)
}

func Test_Arguments_Assert(t *testing.T) {

var args = Arguments([]interface{}{"string", 123, true})
Expand Down Expand Up @@ -1445,7 +1470,7 @@ func TestArgumentMatcherToPrintMismatch(t *testing.T) {
defer func() {
if r := recover(); r != nil {
matchingExp := regexp.MustCompile(
`\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: mock.argumentMatcher\{.*?\}\s+Diff:.*\(int=1\) not matched by func\(int\) bool`)
`\s+mock: Unexpected Method Call\s+-*\s+GetTime\(int\)\s+0: 1\s+The closest call I have is:\s+GetTime\(mock.argumentMatcher\)\s+0: MatchedBy\(func\(int\) bool\)\s+Diff:.*\(int=1\) not matched by func\(int\) bool`)
assert.Regexp(t, matchingExp, r)
}
}()
Expand Down