Skip to content

Commit

Permalink
Update and simplify tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dfava committed Oct 31, 2024
1 parent ed01238 commit 1f7fcdc
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 83 deletions.
2 changes: 1 addition & 1 deletion examples/metrics/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func initdb() (*sql.DB, error) {
}

func populatedb(ctx context.Context, sqldb *sql.DB) error {
qry := `if object_id('dbo.MyUsers', 'U') is not null drop table MyUsers
qry := `drop table if exists
create table MyUsers (
Id int identity(1,1) primary key,
UserName nvarchar(50) not null,
Expand Down
206 changes: 124 additions & 82 deletions querysql/querysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"log"
"testing"
"time"

Expand Down Expand Up @@ -80,68 +81,96 @@ select _function='OtherTestFunction', time=42, money=convert(money, 12345.67);
Y string
}

var ctx context.Context
var hook LogHook
logger := logrus.StandardLogger()
logger.Hooks.Add(&hook)
ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{
testhelper.TestFunction,
testhelper.OtherTestFunction,
}))
rs := querysql.New(ctx, sqldb, qry, "world")
rows := rs.Rows
testhelper.ResetTestFunctionsCalled()

// select 2
assert.Equal(t, 2, querysql.MustNextResult(rs, querysql.SingleOf[int]))
testcases := []struct {
name string
init func()
checkLog func()
}{
{
name: "With logrus logger",
init: func() {
logger := logrus.StandardLogger()
logger.Hooks.Add(&hook)
ctx = querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
},
checkLog: func() {
// Check that we have exhausted the logging select before we do the call that gets ErrNoMoreSets
assert.Equal(t, []logrus.Fields{
{"x": "hello world", "y": int64(1)},
{"x": "hello world2", "y": int64(2)},
{"x": "hello world3", "y": int64(3)},
{"x": "hello world3", "y": int64(4)},
{"_norows": true, "x": ""},
{"log": "at end"},
}, hook.lines)
},
},
{
name: "With std logger",
init: func() {
logger := log.Default()
ctx = querysql.WithLogger(context.Background(), querysql.StdMSSQLLogger(logger))
},
checkLog: func() {},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
tc.init()
ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{
testhelper.TestFunction,
testhelper.OtherTestFunction,
}))
rs := querysql.New(ctx, sqldb, qry, "world")
rows := rs.Rows
testhelper.ResetTestFunctionsCalled()

// select X = 1, Y = 'one';
assert.Equal(t, row{1, "one"}, querysql.MustNextResult(rs, querysql.SingleOf[row]))
// select 2
assert.Equal(t, 2, querysql.MustNextResult(rs, querysql.SingleOf[int]))

// select 'hello' union all select @p1;
assert.Equal(t, []string{"hello", "world"}, querysql.MustNextResult(rs, querysql.SliceOf[string]))
// select X = 1, Y = 'one';
assert.Equal(t, row{1, "one"}, querysql.MustNextResult(rs, querysql.SingleOf[row]))

// select X = 1, Y = 'one' where 1 = 0;
assert.Equal(t, []row(nil), querysql.MustNextResult(rs, querysql.SliceOf[row]))
// select 'hello' union all select @p1;
assert.Equal(t, []string{"hello", "world"}, querysql.MustNextResult(rs, querysql.SliceOf[string]))

// select X = 1, Y = 'one'
// union all select X = 2, Y = 'two';
assert.Equal(t, []row{{1, "one"}, {2, "two"}}, querysql.MustNextResult(rs, querysql.SliceOf[row]))
// select X = 1, Y = 'one' where 1 = 0;
assert.Equal(t, []row(nil), querysql.MustNextResult(rs, querysql.SliceOf[row]))

// select 0x0102030405 union all select 0x0102030406
assert.Equal(t, []MyArray{{1, 2, 3, 4, 5}, {1, 2, 3, 4, 6}}, querysql.MustNextResult(rs, querysql.SliceOf[MyArray]))
// select X = 1, Y = 'one'
// union all select X = 2, Y = 'two';
assert.Equal(t, []row{{1, "one"}, {2, "two"}}, querysql.MustNextResult(rs, querysql.SliceOf[row]))

// select concat('hello ', @p1);
assert.Equal(t, "hello world", querysql.MustNextResult(rs, querysql.SingleOf[string]))
// select 0x0102030405 union all select 0x0102030406
assert.Equal(t, []MyArray{{1, 2, 3, 4, 5}, {1, 2, 3, 4, 6}}, querysql.MustNextResult(rs, querysql.SliceOf[MyArray]))

// select 0x0102030405
assert.Equal(t, MyArray{1, 2, 3, 4, 5}, querysql.MustNextResult(rs, querysql.SingleOf[MyArray]))
// select concat('hello ', @p1);
assert.Equal(t, "hello world", querysql.MustNextResult(rs, querysql.SingleOf[string]))

// select newid()
assert.Equal(t, 16, len(querysql.MustNextResult(rs, querysql.SingleOf[[]uint8])))
// select 0x0102030405
assert.Equal(t, MyArray{1, 2, 3, 4, 5}, querysql.MustNextResult(rs, querysql.SingleOf[MyArray]))

// Check that we have exhausted the logging select before we do the call that gets ErrNoMoreSets
assert.Equal(t, []logrus.Fields{
{"x": "hello world", "y": int64(1)},
{"x": "hello world2", "y": int64(2)},
{"x": "hello world3", "y": int64(3)},
{"x": "hello world3", "y": int64(4)},
{"_norows": true, "x": ""},
{"log": "at end"},
}, hook.lines)
// select newid()
assert.Equal(t, 16, len(querysql.MustNextResult(rs, querysql.SingleOf[[]uint8])))

querysql.NextResult(rs, querysql.SliceOf[string]) // This will process all dispatcher function calls
assert.True(t, testhelper.TestFunctionsCalled["TestFunction"])
assert.True(t, testhelper.TestFunctionsCalled["OtherTestFunction"])
tc.checkLog()

_, err := querysql.NextResult(rs, querysql.SingleOf[int])
assert.Equal(t, querysql.ErrNoMoreSets, err)
assert.True(t, isClosed(rows))
assert.True(t, rs.Done())
assert.True(t, testhelper.TestFunctionsCalled["TestFunction"])
assert.True(t, testhelper.TestFunctionsCalled["OtherTestFunction"])

rs.Close()
assert.True(t, isClosed(rows))
_, err := querysql.NextResult(rs, querysql.SingleOf[int])
assert.Equal(t, querysql.ErrNoMoreSets, err)
assert.True(t, isClosed(rows))
assert.True(t, rs.Done())

err = rs.Close()
assert.NoError(t, err)
assert.True(t, isClosed(rows))
})
}
}

func TestInvalidLogLevel(t *testing.T) {
Expand All @@ -151,19 +180,45 @@ select _log=1, x = 'hello world', y = 1;
`

var hook LogHook
logger := logrus.StandardLogger()
logger.Hooks.Add(&hook)
ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
rs := querysql.New(ctx, sqldb, qry, "world")
err := querysql.NextNoScanner(rs)
assert.Error(t, err)
assert.Equal(t, "no more result sets", err.Error())
var ctx context.Context
testcases := []struct {
name string
init func()
checkLog func()
}{
{
name: "With logrus logger",
init: func() {
logger := logrus.StandardLogger()
logger.Hooks.Add(&hook)
ctx = querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
},
checkLog: func() {
// Check that we have exhausted the logging select before we do the call that gets ErrNoMoreSets
assert.Equal(t, []logrus.Fields{
{"event": "invalid.log.level", "invalid.level": "1"},
{"x": "hello world", "y": int64(1)},
}, hook.lines)
},
},
{
name: "With std logger",
init: func() {
logger := log.Default()
ctx = querysql.WithLogger(context.Background(), querysql.StdMSSQLLogger(logger))
},
checkLog: func() {},
},
}

// Check that we have exhausted the logging select before we do the call that gets ErrNoMoreSets
assert.Equal(t, []logrus.Fields{
{"event": "invalid.log.level", "invalid.level": "1"},
{"x": "hello world", "y": int64(1)},
}, hook.lines)
for _, tc := range testcases {
tc.init()
rs := querysql.New(ctx, sqldb, qry, "world")
err := querysql.NextNoScanner(rs)
assert.Error(t, err)
assert.Equal(t, "no more result sets", err.Error())
tc.checkLog()
}
}

func Test_LogAndException(t *testing.T) {
Expand Down Expand Up @@ -218,16 +273,13 @@ select 2;

func TestDispatcherSetupError(t *testing.T) {
var mustNotBeTrue bool
var hook LogHook
logger := logrus.StandardLogger()
logger.Hooks.Add(&hook)
defer func() {
r := recover()
assert.NotNil(t, r) // nil if a panic didn't happen, not nil if a panic happened
assert.False(t, mustNotBeTrue)
}()

ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
ctx := context.Background()
ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{
"SomethingThatIsNotAFunctionPointer", // This should cause a panic
}))
Expand Down Expand Up @@ -312,7 +364,8 @@ func TestDispatcherRuntimeErrorsAndCornerCases(t *testing.T) {
assert.True(t, isClosed(rows))
assert.True(t, rs.Done())

rs.Close()
err = rs.Close()
assert.NoError(t, err)
assert.True(t, isClosed(rows))
}
}
Expand Down Expand Up @@ -433,13 +486,13 @@ func TestEmptyStruct(t *testing.T) {

func TestEmptyResultWithError(t *testing.T) {
qry := `
if OBJECT_ID('dbo.MyUsers', 'U') is not null drop table MyUsers
drop table if exists MyUsers
create table MyUsers (
ID INT IDENTITY(1,1) PRIMARY KEY,
Username NVARCHAR(50) not null,
Userage int
);
insert into MyUsers (Userage)
insert into MyUsers (Userage)
output inserted.ID
values (42);
`
Expand Down Expand Up @@ -585,7 +638,7 @@ func TestStructScanError(t *testing.T) {

func TestExecContext(t *testing.T) {
qry := `
if OBJECT_ID('dbo.MyUsers', 'U') is not null drop table MyUsers
drop table if exists MyUsers
create table MyUsers (
ID INT IDENTITY(1,1) PRIMARY KEY,
Username NVARCHAR(50)
Expand Down Expand Up @@ -655,7 +708,7 @@ type MyType struct {
b string
}

func (m MyType) Scan(src any) error {
func (m MyType) Scan(_ any) error {
return nil
}

Expand Down Expand Up @@ -717,26 +770,20 @@ values (42.00);

func TestAnonDispatcherFunc(t *testing.T) {
qry := `
if OBJECT_ID('dbo.MyUsers', 'U') is not null drop table MyUsers
drop table if exists MyUsers
create table MyUsers (
ID INT IDENTITY(1,1) PRIMARY KEY,
Username NVARCHAR(50)
);
insert into MyUsers (Username) values ('JohnDoe');
-- logging
select _log='info', Y = 'one';
-- dispatcher
select _function='TestFunction', component = 'abc', val=1, time=1.23;
select _function='ReturnAnonFunc', label = 'myLabel', time=1.23;
`

var hook LogHook
logger := logrus.StandardLogger()
logger.Hooks.Add(&hook)
ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
ctx := context.Background()
ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{
testhelper.TestFunction,
testhelper.ReturnAnonFunc("myComponent"),
Expand All @@ -748,11 +795,6 @@ select _function='ReturnAnonFunc', label = 'myLabel', time=1.23;

assert.True(t, testhelper.TestFunctionsCalled["ReturnAnonFunc.myComponent"])

// Check that we have exhausted the logging select before we do the call that gets ErrNoMoreSets
assert.Equal(t, []logrus.Fields{
{"Y": "one"},
}, hook.lines)

assert.True(t, testhelper.TestFunctionsCalled["TestFunction"])
}

Expand All @@ -767,7 +809,7 @@ func TestDispatcherPanicsWithTwoAnonFuncs(t *testing.T) {
assert.False(t, mustNotBeTrue)
}()

ctx := querysql.WithLogger(context.Background(), querysql.LogrusMSSQLLogger(logger, logrus.InfoLevel))
ctx := context.Background()
ctx = querysql.WithDispatcher(ctx, querysql.GoMSSQLDispatcher([]interface{}{
testhelper.ReturnAnonFunc("myComponent"),
testhelper.ReturnAnonFunc("myComponent2"), // This should cause a panic
Expand Down

0 comments on commit 1f7fcdc

Please sign in to comment.