diff --git a/named.go b/named.go index 2a96e8ed..fbf9ec12 100644 --- a/named.go +++ b/named.go @@ -224,11 +224,11 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) return bound, arglist, nil } -var valueBracketReg = regexp.MustCompile(`\([^(]*.[^(]\)\s*$`) +var valueBracketReg = regexp.MustCompile(`(?i)(VALUES\s+)\([^()]+(?:\(\s*\))?(?:\s*,\s*[^()]+(?:\(\s*\))?)*\)`) func fixBound(bound string, loop int) string { - loc := valueBracketReg.FindStringIndex(bound) - if len(loc) != 2 { + loc := valueBracketReg.FindStringSubmatchIndex(bound) + if len(loc) != 4 { return bound } var buffer bytes.Buffer @@ -236,7 +236,7 @@ func fixBound(bound string, loop int) string { buffer.WriteString(bound[0:loc[1]]) for i := 0; i < loop-1; i++ { buffer.WriteString(",") - buffer.WriteString(bound[loc[0]:loc[1]]) + buffer.WriteString(bound[loc[3]:loc[1]]) } buffer.WriteString(bound[loc[1]:]) return buffer.String() diff --git a/named_test.go b/named_test.go index 0cb4088b..353bb922 100644 --- a/named_test.go +++ b/named_test.go @@ -104,6 +104,7 @@ type Test struct { } func (t Test) Error(err error, msg ...interface{}) { + t.t.Helper() if err != nil { if len(msg) == 0 { t.t.Error(err) @@ -114,6 +115,7 @@ func (t Test) Error(err error, msg ...interface{}) { } func (t Test) Errorf(err error, format string, args ...interface{}) { + t.t.Helper() if err != nil { t.t.Errorf(format, args...) } @@ -296,3 +298,88 @@ func TestNamedQueries(t *testing.T) { }) } + +func TestNamedBulkInsert(t *testing.T) { + type Val struct { + K string `db:"k"` + V int `db:"v"` + } + + vs := []Val{{K: "x"}, {K: "y"}, {K: "z"}} + table := []struct { + values []interface{} + q, expected string + }{ + { + values: []interface{}{vs[0]}, + q: "INSERT INTO val (k) VALUES (:k)", + expected: "INSERT INTO val (k) VALUES (?)", + }, + { + values: []interface{}{vs[0], vs[1]}, + q: "INSERT INTO val (k) VALUES (:k)", + expected: "INSERT INTO val (k) VALUES (?),(?)", + }, + { + values: []interface{}{vs[0]}, + q: "INSERT INTO val (k,v) VALUES (:k,:v)", + expected: "INSERT INTO val (k,v) VALUES (?,?)", + }, + { + values: []interface{}{vs[0], vs[1]}, + q: "INSERT INTO val (k,v) VALUES (:k,:v)", + expected: "INSERT INTO val (k,v) VALUES (?,?),(?,?)", + }, + { + values: []interface{}{vs[0]}, + q: "INSERT INTO val (k,v) VALUES ( :k, :v )", + expected: "INSERT INTO val (k,v) VALUES ( ?, ? )", + }, + { + values: []interface{}{vs[0], vs[1]}, + q: "INSERT INTO val (k,v) VALUES ( :k, :v )", + expected: "INSERT INTO val (k,v) VALUES ( ?, ? ),( ?, ? )", + }, + // sql functions (0 arguments) + { + values: []interface{}{vs[0], vs[1]}, + q: func() string { + _, _, now := defaultSchema.Postgres() + return fmt.Sprintf("INSERT INTO val (k, v, added_at) VALUES (:k, :v, %v)", now) + }(), + expected: "INSERT INTO val (k, v, added_at) VALUES (?, ?, now()),(?, ?, now())", + }, + { + values: []interface{}{vs[0], vs[1]}, + q: func() string { + _, _, now := defaultSchema.MySQL() + return fmt.Sprintf("INSERT INTO val (k, v, added_at) VALUES (:k, :v, %v)", now) + }(), + expected: "INSERT INTO val (k, v, added_at) VALUES (?, ?, now()),(?, ?, now())", + }, + { + values: []interface{}{vs[0], vs[1]}, + q: func() string { + _, _, now := defaultSchema.Sqlite3() + return fmt.Sprintf("INSERT INTO val (k, v, added_at) VALUES (:k, :v, %v)", now) + }(), + expected: "INSERT INTO val (k, v, added_at) VALUES (?, ?, CURRENT_TIMESTAMP),(?, ?, CURRENT_TIMESTAMP)", + }, + // extra operation + { + values: []interface{}{vs[0], vs[1]}, + q: "INSERT INTO val (k,v) VALUES (:k,:v) ON DUPLICATE KEY UPDATE v = VALUES(v)", + expected: "INSERT INTO val (k,v) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE v = VALUES(v)", + }, + } + + for _, test := range table { + actual, _, err := Named(test.q, test.values) + if err != nil { + t.Fatalf("unexpected error %+v, when len(values) == %d", err, len(test.values)) + } + if test.expected != actual { + t.Errorf("expected query is (len(values) == %d)\n\t%q\nbut actual result is\n\t%q", len(test.values), test.expected, actual) + } + } +}