Skip to content

Commit

Permalink
Handle types that implement sql.Scanner and handle time.Time (#19)
Browse files Browse the repository at this point in the history
* Handle types that implement sql.Scanner interface
* Handle time.Time
  • Loading branch information
dfava authored Sep 26, 2024
1 parent 96adc63 commit 5e091a1
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 7 deletions.
22 changes: 20 additions & 2 deletions querysql/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package querysql
import (
"database/sql"
"reflect"
"time"
)

type typeinfo struct {
valid bool
isStruct bool // use special struct demarshalling; otherwise standard SQL Scan
valid bool
isStruct bool // use special struct demarshalling; otherwise standard SQL Scan
implementsScan bool
isTimeDotTime bool
}

var sqlScannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
Expand All @@ -17,6 +20,21 @@ func inspectType[T any]() typeinfo {
typ := reflect.TypeOf(zeroValue)
kind := typ.Kind()

switch any(zeroValue).(type) {
case sql.Scanner:
// Check if type implements the Scanner interface
return typeinfo{
valid: true,
implementsScan: true,
}
case time.Time:
// underlying sql package automatically converts DATETIME or TIMESTAMP to time.Time
return typeinfo{
valid: true,
isTimeDotTime: true,
}
}

if kind == reflect.Struct {
return typeinfo{
valid: true,
Expand Down
55 changes: 51 additions & 4 deletions querysql/querysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"testing"
"time"

"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -35,23 +36,23 @@ func TestInspectType(t *testing.T) {
expected, got typeinfo
}{
{
expected: typeinfo{true, false},
expected: typeinfo{true, false, false, false},
got: inspectType[int](),
},
{
expected: typeinfo{true, false},
expected: typeinfo{true, false, false, false},
got: inspectType[[]byte](),
},
{
expected: typeinfo{true, true},
expected: typeinfo{true, true, false, false},
got: inspectType[mystruct](),
},
{
expected: typeinfo{valid: false},
got: inspectType[[]mystruct](),
},
{
expected: typeinfo{true, false},
expected: typeinfo{true, false, false, false},
got: inspectType[MyArray](),
},
{
Expand Down Expand Up @@ -683,3 +684,49 @@ select _function='TestFunction', component = 'abc', val=1, time=1.23;

assert.True(t, testhelper.TestFunctionsCalled["TestFunction"])
}

func Test_timeDotTime(t *testing.T) {
testcases := []struct {
name string
qry string
expected string
err error
}{
{
name: "Scan into time.Time",
qry: `select sysutcdatetime();`,
},
}
ctx := context.Background()
for _, tc := range testcases {
res, err := Single[time.Time](ctx, sqldb, tc.qry, "world")
if err == nil {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
if tc.expected != "" {
assert.Equal(t, tc.expected, res)
}
}
}

type MyType struct {
a int
b string
}

func (m MyType) Scan(src any) error {
return nil
}

var _ sql.Scanner = MyType{}

func Test_TypeThatImplementsScan(t *testing.T) {
qry := `select 1`
ctx := context.Background()
// If MyType doesn't implement Scan, then querysql will try to put the result of the `select 1`
// into the `MyType struct{int, string}` and querysql will blow up with the error `failed to map all struct fields to query columns`
_, err := Single[MyType](ctx, sqldb, qry, "world")
assert.NoError(t, err)
}
2 changes: 1 addition & 1 deletion querysql/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ type RowScanner[T any] struct {
func (scanner *RowScanner[T]) scanRow(rows *sql.Rows) error {
if !scanner.init {
scanner.init = true

scanner.typeinfo = inspectType[T]()
if !scanner.typeinfo.valid {
return fmt.Errorf("query.ScanRow: illegal type parameter T")
}

if scanner.isStruct {
var err error
scanner.scanPointers, err = getPointersToFields(rows, scanner.target)
Expand Down

0 comments on commit 5e091a1

Please sign in to comment.