diff --git a/named.go b/named.go index 1f416121..728aa04d 100644 --- a/named.go +++ b/named.go @@ -224,29 +224,47 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) return bound, arglist, nil } -var valueBracketReg = regexp.MustCompile(`(?i)VALUES\s*(\([^(]*.[^(]\))`) +var valuesReg = regexp.MustCompile(`\)\s*(?i)VALUES\s*\(`) -func fixBound(bound string, loop int) string { +func findMatchingClosingBracketIndex(s string) int { + count := 0 + for i, ch := range s { + if ch == '(' { + count++ + } + if ch == ')' { + count-- + if count == 0 { + return i + } + } + } + return 0 +} - loc := valueBracketReg.FindAllStringSubmatchIndex(bound, -1) - // Either no VALUES () found or more than one found?? - if len(loc) != 1 { +func fixBound(bound string, loop int) string { + loc := valuesReg.FindStringIndex(bound) + // defensive guard when "VALUES (...)" not found + if len(loc) < 2 { return bound } - // defensive guard. loc should be len 4 representing the starting and - // ending index for the whole regex match and the starting + ending - // index for the single inside group - if len(loc[0]) != 4 { + + openingBracketIndex := loc[1] - 1 + index := findMatchingClosingBracketIndex(bound[openingBracketIndex:]) + // defensive guard. must have closing bracket + if index == 0 { return bound } + closingBracketIndex := openingBracketIndex + index + 1 + var buffer bytes.Buffer - buffer.WriteString(bound[0:loc[0][1]]) + buffer.WriteString(bound[0:closingBracketIndex]) for i := 0; i < loop-1; i++ { buffer.WriteString(",") - buffer.WriteString(bound[loc[0][2]:loc[0][3]]) + buffer.WriteString(bound[openingBracketIndex:closingBracketIndex]) } - buffer.WriteString(bound[loc[0][1]:]) + buffer.WriteString(bound[closingBracketIndex:]) return buffer.String() } diff --git a/named_test.go b/named_test.go index 70bc4484..8481b35b 100644 --- a/named_test.go +++ b/named_test.go @@ -3,7 +3,6 @@ package sqlx import ( "database/sql" "fmt" - "regexp" "testing" ) @@ -105,6 +104,7 @@ type Test struct { } func (t Test) Error(err error, msg ...interface{}) { + t.t.Helper() if err != nil { if len(msg) == 0 { t.t.Error(err) @@ -115,6 +115,7 @@ func (t Test) Error(err error, msg ...interface{}) { } func (t Test) Errorf(err error, format string, args ...interface{}) { + t.t.Helper() if err != nil { t.t.Errorf(format, args...) } @@ -339,7 +340,7 @@ func TestFixBounds(t *testing.T) { { name: `found twice test`, query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, - expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, + expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`, loop: 2, }, { @@ -354,6 +355,73 @@ func TestFixBounds(t *testing.T) { expect: `INSERT INTO foo (a,b) values(:a, :b),(:a, :b)`, loop: 2, }, + { + name: `on duplicate key using VALUES`, + query: `INSERT INTO foo (a,b) VALUES (:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, + expect: `INSERT INTO foo (a,b) VALUES (:a, :b),(:a, :b) ON DUPLICATE KEY UPDATE a=VALUES(a)`, + loop: 2, + }, + { + name: `single column`, + query: `INSERT INTO foo (a) VALUES (:a)`, + expect: `INSERT INTO foo (a) VALUES (:a),(:a)`, + loop: 2, + }, + { + name: `call now`, + query: `INSERT INTO foo (a, b) VALUES (:a, NOW())`, + expect: `INSERT INTO foo (a, b) VALUES (:a, NOW()),(:a, NOW())`, + loop: 2, + }, + { + name: `two level depth function call`, + query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW()))`, + expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())),(:a, YEAR(NOW()))`, + loop: 2, + }, + { + name: `missing closing bracket`, + query: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, + expect: `INSERT INTO foo (a, b) VALUES (:a, YEAR(NOW())`, + loop: 2, + }, + { + name: `table with "values" at the end`, + query: `INSERT INTO table_values (a, b) VALUES (:a, :b)`, + expect: `INSERT INTO table_values (a, b) VALUES (:a, :b),(:a, :b)`, + loop: 2, + }, + { + name: `multiline indented query`, + query: `INSERT INTO foo ( + a, + b, + c, + d + ) VALUES ( + :name, + :age, + :first, + :last + )`, + expect: `INSERT INTO foo ( + a, + b, + c, + d + ) VALUES ( + :name, + :age, + :first, + :last + ),( + :name, + :age, + :first, + :last + )`, + loop: 2, + }, } for _, tc := range table { @@ -364,18 +432,4 @@ func TestFixBounds(t *testing.T) { } }) } - - t.Run("regex changed", func(t *testing.T) { - var valueBracketRegChanged = regexp.MustCompile(`(VALUES)\s+(\([^(]*.[^(]\))`) - saveRegexp := valueBracketReg - defer func() { - valueBracketReg = saveRegexp - }() - valueBracketReg = valueBracketRegChanged - - res := fixBound("VALUES (:a, :b)", 2) - if res != "VALUES (:a, :b)" { - t.Errorf("changed regex should return string") - } - }) }