Skip to content

Commit

Permalink
Validate single row count for Row.Scan
Browse files Browse the repository at this point in the history
Validate that a single row only is returned by queries used for Row.Scan.

This avoid unexpected results when the query has an issue such as missing join criteria or limit in conjunction with functions which expect only on row returned e.g. Get(...).

Also:
* Fixed missing \n's for test output of ConnectAll.
  • Loading branch information
stevenh committed Oct 13, 2018
1 parent 0dae4fe commit f8641fc
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
25 changes: 19 additions & 6 deletions sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ import (
"github.com/jmoiron/sqlx/reflectx"
)

// ErrMultiRows is returned by functions which are expected to work with result sets
// that only contain a single row but multiple rows where returned.
// This typically indicates an issue with the query such as a missing join criteria or
// limit condition or the use of Get(...) when Select(...) was intended.
var ErrMultiRows = errors.New("sql: multiple rows returned")

// Although the NameMapper is convenient, in practice it should not
// be relied on except for application code. If you are writing a library
// that uses sqlx, you should be aware that the name mappings you expect
Expand Down Expand Up @@ -177,6 +183,7 @@ type Row struct {

// Scan is a fixed implementation of sql.Row.Scan, which does not discard the
// underlying error from the internal rows object if it exists.
// Returns ErrMultiRows if the result set contains more than one row.
func (r *Row) Scan(dest ...interface{}) error {
if r.err != nil {
return r.err
Expand Down Expand Up @@ -208,10 +215,16 @@ func (r *Row) Scan(dest ...interface{}) error {
}
return sql.ErrNoRows
}
err := r.rows.Scan(dest...)
if err != nil {
if err := r.rows.Scan(dest...); err != nil {
return err
}

if r.rows.Next() {
return ErrMultiRows
} else if err := r.rows.Err(); err != nil {
return err
}

// Make sure the query can be processed to completion with no errors.
if err := r.rows.Close(); err != nil {
return err
Expand Down Expand Up @@ -323,7 +336,7 @@ func (db *DB) Select(dest interface{}, query string, args ...interface{}) error

// Get using this DB.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
return Get(db, dest, query, args...)
}
Expand Down Expand Up @@ -446,7 +459,7 @@ func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {

// Get within a transaction.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
return Get(tx, dest, query, args...)
}
Expand Down Expand Up @@ -516,7 +529,7 @@ func (s *Stmt) Select(dest interface{}, args ...interface{}) error {

// Get using the prepared statement.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func (s *Stmt) Get(dest interface{}, args ...interface{}) error {
return Get(&qStmt{s}, dest, "", args...)
}
Expand Down Expand Up @@ -682,7 +695,7 @@ func Select(q Queryer, dest interface{}, query string, args ...interface{}) erro
// to dest. If dest is scannable, the result must only have one column. Otherwise,
// StructScan is used. Get will return sql.ErrNoRows like row.Scan would.
// Any placeholder parameters are replaced with supplied args.
// An error is returned if the result set is empty.
// An error is returned if the result set is empty or contains more than one row.
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
r := q.QueryRowx(query, args...)
return r.scanAny(dest, false)
Expand Down
44 changes: 42 additions & 2 deletions sqlx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func ConnectAll() {
if TestMysql {
mysqldb, err = Connect("mysql", mydsn)
if err != nil {
fmt.Printf("Disabling MySQL tests:\n %v", err)
fmt.Printf("Disabling MySQL tests:\n %v\n", err)
TestMysql = false
}
} else {
Expand All @@ -85,7 +85,7 @@ func ConnectAll() {
if TestSqlite {
sldb, err = Connect("sqlite3", sqdsn)
if err != nil {
fmt.Printf("Disabling SQLite:\n %v", err)
fmt.Printf("Disabling SQLite:\n %v\n", err)
TestSqlite = false
}
} else {
Expand Down Expand Up @@ -1708,6 +1708,46 @@ func TestEmbeddedLiterals(t *testing.T) {
})
}

// TestGet tests to ensure that Get behaves correctly for
// single row and multi row results.
func TestGet(t *testing.T) {
var schema = Schema{
create: `CREATE TABLE tst (v integer);`,
drop: `drop table tst;`,
}

RunWithSchema(schema, t, func(db *DB, t *testing.T) {
for _, v := range []int{1, 2} {
_, err := db.Exec(db.Rebind("INSERT INTO tst (v) VALUES (?)"), v)
if err != nil {
t.Error(err)
}
}

tests := []struct {
name string
val int
err bool
}{
{"multi-rows", 1, true},
{"single-row", 2, false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var v int
err := db.Get(&v, db.Rebind("SELECT v FROM tst WHERE v >= ?"), tc.val)
if tc.err {
if err == nil {
t.Error("expected error but got nil")
}
} else if err != nil {
t.Error("unexpected error:", err)
}
})
}
})
}

func BenchmarkBindStruct(b *testing.B) {
b.StopTimer()
q1 := `INSERT INTO foo (a, b, c, d) VALUES (:name, :age, :first, :last)`
Expand Down

0 comments on commit f8641fc

Please sign in to comment.