diff --git a/querysql/gomssqldispatcher.go b/querysql/gomssqldispatcher.go index 5e5d3ba..cbe87ec 100644 --- a/querysql/gomssqldispatcher.go +++ b/querysql/gomssqldispatcher.go @@ -4,16 +4,18 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "runtime" "strconv" "strings" ) type funcInfo struct { - name string - numArgs int - argType []reflect.Type - valueOf reflect.Value + name string + numArgs int + isClosure bool + argType []reflect.Type + valueOf reflect.Value } func GoMSSQLDispatcher(fs []interface{}) RowsGoDispatcher { @@ -30,13 +32,21 @@ func GoMSSQLDispatcher(fs []interface{}) RowsGoDispatcher { } fInfo.valueOf = reflect.ValueOf(f) - getFunctionName := func(fullName string) string { + getFunctionName := func(fullName string) (string, bool) { paths := strings.Split(fullName, "/") lastPath := paths[len(paths)-1] parts := strings.Split(lastPath, ".") - return parts[len(parts)-1] + fName := parts[len(parts)-1] + matched, err := regexp.Match(`func\d+`, []byte(fName)) + if err != nil { + panic(err.Error()) + } + if matched { + return parts[len(parts)-2], true // It is a closure + } + return fName, false } - fInfo.name = getFunctionName(runtime.FuncForPC(fInfo.valueOf.Pointer()).Name()) + fInfo.name, fInfo.isClosure = getFunctionName(runtime.FuncForPC(fInfo.valueOf.Pointer()).Name()) if knownFuncs == "" { knownFuncs = fmt.Sprintf("'%s'", fInfo.name) @@ -51,6 +61,9 @@ func GoMSSQLDispatcher(fs []interface{}) RowsGoDispatcher { for i := 0; i < fInfo.numArgs; i++ { fInfo.argType[i] = funcType.In(i) } + if _, in := funcMap[fInfo.name]; in { + panic(fmt.Sprintf("Function already in dispatcher %s (closure==%v)", fInfo.name, fInfo.isClosure)) + } funcMap[fInfo.name] = fInfo } diff --git a/querysql/querysql_test.go b/querysql/querysql_test.go index 874a9c1..f932217 100644 --- a/querysql/querysql_test.go +++ b/querysql/querysql_test.go @@ -714,3 +714,64 @@ values (42.00); assert.NoError(t, err) assert.Equal(t, "42.00", m.String()) } + +func TestAnonDispatcherFunc(t *testing.T) { + qry := ` +if OBJECT_ID('dbo.MyUsers', 'U') is not null drop table MyUsers +create table MyUsers ( + ID INT IDENTITY(1,1) PRIMARY KEY, + Username NVARCHAR(50) +); +insert into MyUsers (Username) values ('JohnDoe'); + +-- logging +select _log='info', Y = 'one'; + +-- dispatcher +select _function='TestFunction', component = 'abc', val=1, time=1.23; + +select _function='ReturnAnonFunc', label = 'myLabel', time=1.23; +` + + var hook LogHook + logger := logrus.StandardLogger() + logger.Hooks.Add(&hook) + ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel)) + ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{ + testhelper.TestFunction, + testhelper.ReturnAnonFunc("myComponent"), + })) + testhelper.ResetTestFunctionsCalled() + + _, err := querysql.ExecContext(ctx, sqldb, qry, "world") + assert.NoError(t, err) + + assert.True(t, testhelper.TestFunctionsCalled["ReturnAnonFunc.myComponent"]) + + // Check that we have exhausted the logging select before we do the call that gets ErrNoMoreSets + assert.Equal(t, []logrus.Fields{ + {"Y": "one"}, + }, hook.lines) + + assert.True(t, testhelper.TestFunctionsCalled["TestFunction"]) +} + +func TestDispatcherPanicsWithTwoAnonFuncs(t *testing.T) { + var mustNotBeTrue bool + var hook LogHook + logger := logrus.StandardLogger() + logger.Hooks.Add(&hook) + defer func() { + r := recover() + assert.NotNil(t, r) // nil if a panic didn't happen, not nil if a panic happened + assert.False(t, mustNotBeTrue) + }() + + ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel)) + ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{ + testhelper.ReturnAnonFunc("myComponent"), + testhelper.ReturnAnonFunc("myComponent2"), // This should cause a panic + })) + // Nothing here gets executed because we expect the WithDispatcher to have panicked + mustNotBeTrue = true +} diff --git a/querysql/testhelper/helper.go b/querysql/testhelper/helper.go index 4e6fd87..3dac681 100644 --- a/querysql/testhelper/helper.go +++ b/querysql/testhelper/helper.go @@ -23,6 +23,12 @@ func OtherTestFunction(time float64, money float64) { TestFunctionsCalled[getFunctionName()] = true } +func ReturnAnonFunc(component string) func(string, float64) { + return func(label string, time float64) { + TestFunctionsCalled[fmt.Sprintf("ReturnAnonFunc.%s", component)] = true + } +} + func ResetTestFunctionsCalled() { for k, _ := range TestFunctionsCalled { TestFunctionsCalled[k] = false