diff --git a/pkg/sqlmodel/multivalue.go b/pkg/sqlmodel/multirow.go similarity index 67% rename from pkg/sqlmodel/multivalue.go rename to pkg/sqlmodel/multirow.go index b1dc05ab381..8e61ff1d942 100644 --- a/pkg/sqlmodel/multivalue.go +++ b/pkg/sqlmodel/multirow.go @@ -1,4 +1,4 @@ -// Copyright 2022 PingCAP, Inc. +// Copyright 2023 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,16 +16,14 @@ package sqlmodel import ( "strings" + "github.com/pingcap/log" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/opcode" driver "github.com/pingcap/tidb/types/parser_driver" - - "go.uber.org/zap" - - "github.com/pingcap/tiflow/dm/pkg/log" "github.com/pingcap/tiflow/pkg/quotes" + "go.uber.org/zap" ) // SameTypeTargetAndColumns check whether two row changes have same type, target @@ -309,7 +307,7 @@ func GenUpdateSQL(changes ...*RowChange) (string, []interface{}) { // a simple check about different number of WHERE values, not trying to // cover all cases if len(whereValues) != len(whereColumns) { - log.L().DPanic("len(whereValues) != len(whereColumns)", + log.Panic("len(whereValues) != len(whereColumns)", zap.Int("len(whereValues)", len(whereValues)), zap.Int("len(whereColumns)", len(whereColumns)), zap.Any("whereValues", whereValues), @@ -346,3 +344,149 @@ func GenUpdateSQL(changes ...*RowChange) (string, []interface{}) { } return buf.String(), args } + +// GenUpdateSQLFast generates the UPDATE SQL and its arguments. +// Input `changes` should have same target table and same columns for WHERE +// (typically same PK/NOT NULL UK), otherwise the behaviour is undefined. +// It is a faster version compared with GenUpdateSQL. +func GenUpdateSQLFast(changes ...*RowChange) (string, []any) { + if len(changes) == 0 { + log.L().DPanic("row changes is empty") + return "", nil + } + var buf strings.Builder + buf.Grow(1024) + + // Generate UPDATE `db`.`table` SET + first := changes[0] + buf.WriteString("UPDATE ") + buf.WriteString(first.targetTable.QuoteString()) + buf.WriteString(" SET ") + + // Pre-generate essential sub statements used after WHEN, WHERE and IN. + var ( + whereCaseStmt string + whenCaseStmt string + inCaseStmt string + ) + whereColumns, _ := first.whereColumnsAndValues() + if len(whereColumns) == 1 { + // one field PK or UK, use `field`=? directly. + whereCaseStmt = quotes.QuoteName(whereColumns[0]) + whenCaseStmt = whereCaseStmt + "=?" + inCaseStmt = valuesHolder(len(changes)) + } else { + // multiple fields PK or UK, use ROW(...fields) expression. + whereValuesHolder := valuesHolder(len(whereColumns)) + whereCaseStmt = "ROW(" + for i, column := range whereColumns { + whereCaseStmt += quotes.QuoteName(column) + if i != len(whereColumns)-1 { + whereCaseStmt += "," + } else { + whereCaseStmt += ")" + whenCaseStmt = whereCaseStmt + "=ROW" + whereValuesHolder + } + } + var inCaseStmtBuf strings.Builder + // inCaseStmt sample: IN (ROW(?,?,?),ROW(?,?,?)) + // ^ ^ + // Buffer size count between |---------------------| + // equals to 3 * len(changes) for each `ROW` + // plus 1 * len(changes) - 1 for each `,` between every two ROW(?,?,?) + // plus len(whereValuesHolder) * len(changes) + // plus 2 for `(` and `)` + inCaseStmtBuf.Grow((4+len(whereValuesHolder))*len(changes) + 1) + inCaseStmtBuf.WriteString("(") + for i := range changes { + inCaseStmtBuf.WriteString("ROW") + inCaseStmtBuf.WriteString(whereValuesHolder) + if i != len(changes)-1 { + inCaseStmtBuf.WriteString(",") + } else { + inCaseStmtBuf.WriteString(")") + } + } + inCaseStmt = inCaseStmtBuf.String() + } + + // Generate `ColumnName`=CASE WHEN .. THEN .. END + // Use this value in order to identify which is the first CaseWhenThen line, + // because generated column can happen any where and it will be skipped. + isFirstCaseWhenThenLine := true + for _, column := range first.targetTableInfo.Columns { + if isGenerated(first.targetTableInfo.Columns, column.Name) { + continue + } + if !isFirstCaseWhenThenLine { + // insert ", " after END of each lines except for the first line. + buf.WriteString(", ") + } + + buf.WriteString(quotes.QuoteName(column.Name.String()) + "=CASE") + for range changes { + buf.WriteString(" WHEN ") + buf.WriteString(whenCaseStmt) + buf.WriteString(" THEN ?") + } + buf.WriteString(" END") + isFirstCaseWhenThenLine = false + } + + // Generate WHERE .. IN .. + buf.WriteString(" WHERE ") + buf.WriteString(whereCaseStmt) + buf.WriteString(" IN ") + buf.WriteString(inCaseStmt) + + // Build args of the UPDATE SQL + var assignValueColumnCount int + var skipColIdx []int + for i, col := range first.sourceTableInfo.Columns { + if isGenerated(first.targetTableInfo.Columns, col.Name) { + skipColIdx = append(skipColIdx, i) + continue + } + assignValueColumnCount++ + } + args := make([]any, 0, + assignValueColumnCount*len(changes)*(len(whereColumns)+1)+len(changes)*len(whereColumns)) + argsPerCol := make([][]any, assignValueColumnCount) + for i := 0; i < assignValueColumnCount; i++ { + argsPerCol[i] = make([]any, 0, len(changes)*(len(whereColumns)+1)) + } + whereValuesAtTheEnd := make([]any, 0, len(changes)*len(whereColumns)) + for _, change := range changes { + _, whereValues := change.whereColumnsAndValues() + // a simple check about different number of WHERE values, not trying to + // cover all cases + if len(whereValues) != len(whereColumns) { + log.Panic("len(whereValues) != len(whereColumns)", + zap.Int("len(whereValues)", len(whereValues)), + zap.Int("len(whereColumns)", len(whereColumns)), + zap.Any("whereValues", whereValues), + zap.Stringer("sourceTable", change.sourceTable)) + return "", nil + } + + whereValuesAtTheEnd = append(whereValuesAtTheEnd, whereValues...) + + i := 0 // used as index of skipColIdx + writeableCol := 0 + for j, val := range change.postValues { + if i < len(skipColIdx) && skipColIdx[i] == j { + i++ + continue + } + argsPerCol[writeableCol] = append(argsPerCol[writeableCol], whereValues...) + argsPerCol[writeableCol] = append(argsPerCol[writeableCol], val) + writeableCol++ + } + } + for _, a := range argsPerCol { + args = append(args, a...) + } + args = append(args, whereValuesAtTheEnd...) + + return buf.String(), args +} diff --git a/pkg/sqlmodel/multirow_bench_test.go b/pkg/sqlmodel/multirow_bench_test.go new file mode 100644 index 00000000000..496758274b3 --- /dev/null +++ b/pkg/sqlmodel/multirow_bench_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "fmt" + "testing" + "time" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" +) + +func prepareDataOneColoumnPK(t *testing.T, batch int) []*RowChange { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb (c INT, c2 INT, c3 INT, + c4 VARCHAR(10), c5 VARCHAR(100), c6 VARCHAR(1000), PRIMARY KEY (c))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c INT, c2 INT, c3 INT, + c4 VARCHAR(10), c5 VARCHAR(100), c6 VARCHAR(1000), PRIMARY KEY (c))`) + + changes := make([]*RowChange, 0, batch) + for i := 0; i < batch; i++ { + change := NewRowChange(source, target, + []interface{}{i + 1, i + 2, i + 3, "c4", "c5", "c6"}, + []interface{}{i + 10, i + 20, i + 30, "c4", "c5", "c6"}, + sourceTI, targetTI, nil) + changes = append(changes, change) + } + return changes +} + +func prepareDataMultiColumnsPK(t *testing.T, batch int) []*RowChange { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb (c1 INT, c2 INT, c3 INT, c4 INT, + c5 VARCHAR(10), c6 VARCHAR(100), c7 VARCHAR(1000), c8 timestamp, c9 timestamp, + PRIMARY KEY (c1, c2, c3, c4))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c1 INT, c2 INT, c3 INT, c4 INT, + c5 VARCHAR(10), c6 VARCHAR(100), c7 VARCHAR(1000), c8 timestamp, c9 timestamp, + PRIMARY KEY (c1, c2, c3, c4))`) + + changes := make([]*RowChange, 0, batch) + for i := 0; i < batch; i++ { + change := NewRowChange(source, target, + []interface{}{i + 1, i + 2, i + 3, i + 4, "c4", "c5", "c6", "c7", time.Time{}, time.Time{}}, + []interface{}{i + 10, i + 20, i + 30, i + 40, "c4", "c5", "c6", "c7", time.Time{}, time.Time{}}, + sourceTI, targetTI, nil) + changes = append(changes, change) + } + return changes +} + +// bench cmd: go test -run='^$' -benchmem -bench '^(BenchmarkGenUpdate)$' github.com/pingcap/tiflow/pkg/sqlmodel +func BenchmarkGenUpdate(b *testing.B) { + t := &testing.T{} + type genCase struct { + name string + fn genSQLFunc + prepare func(t *testing.T, batch int) []*RowChange + } + batchSizes := []int{ + 1, 2, 4, 8, 16, 32, 64, 128, + } + benchCases := []genCase{ + { + name: "OneColumnPK-GenUpdateSQL", + fn: GenUpdateSQL, + prepare: prepareDataOneColoumnPK, + }, + { + name: "OneColumnPK-GenUpdateSQLFast", + fn: GenUpdateSQLFast, + prepare: prepareDataOneColoumnPK, + }, + { + name: "MultiColumnsPK-GenUpdateSQL", + fn: GenUpdateSQL, + prepare: prepareDataMultiColumnsPK, + }, + { + name: "MultiColumnsPK-GenUpdateSQLFast", + fn: GenUpdateSQLFast, + prepare: prepareDataMultiColumnsPK, + }, + } + for _, bc := range benchCases { + for _, batch := range batchSizes { + name := fmt.Sprintf("%s-Batch%d", bc.name, batch) + b.Run(name, func(b *testing.B) { + changes := prepareDataOneColoumnPK(t, batch) + for i := 0; i < b.N; i++ { + bc.fn(changes...) + } + }) + } + } +} diff --git a/pkg/sqlmodel/multirow_test.go b/pkg/sqlmodel/multirow_test.go new file mode 100644 index 00000000000..3fd84ec9a2e --- /dev/null +++ b/pkg/sqlmodel/multirow_test.go @@ -0,0 +1,240 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlmodel + +import ( + "testing" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" + "github.com/stretchr/testify/require" +) + +type genSQLFunc func(changes ...*RowChange) (string, []interface{}) + +func TestGenDeleteMultiRows(t *testing.T) { + t.Parallel() + + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT PRIMARY KEY, c2 INT)") + + change1 := NewRowChange(source1, target, []interface{}{1, 2}, nil, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, []interface{}{3, 4}, nil, sourceTI2, targetTI, nil) + sql, args := GenDeleteSQL(change1, change2) + + require.Equal(t, "DELETE FROM `db`.`tb` WHERE (`c`) IN ((?),(?))", sql) + require.Equal(t, []interface{}{1, 3}, args) +} + +func TestGenUpdateMultiRows(t *testing.T) { + t.Parallel() + testGenUpdateMultiRows(t, GenUpdateSQL) + testGenUpdateMultiRows(t, GenUpdateSQLFast) +} + +func TestGenUpdateMultiRowsOneColPK(t *testing.T) { + t.Parallel() + testGenUpdateMultiRowsOneColPK(t, GenUpdateSQL) + testGenUpdateMultiRowsOneColPK(t, GenUpdateSQLFast) +} + +func TestGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T) { + t.Parallel() + testGenUpdateMultiRowsWithVirtualGeneratedColumn(t, GenUpdateSQL) + testGenUpdateMultiRowsWithVirtualGeneratedColumn(t, GenUpdateSQLFast) + testGenUpdateMultiRowsWithVirtualGeneratedColumns(t, GenUpdateSQL) + testGenUpdateMultiRowsWithVirtualGeneratedColumns(t, GenUpdateSQLFast) +} + +func TestGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T) { + t.Parallel() + testGenUpdateMultiRowsWithStoredGeneratedColumn(t, GenUpdateSQL) + testGenUpdateMultiRowsWithStoredGeneratedColumn(t, GenUpdateSQLFast) +} + +func testGenUpdateMultiRows(t *testing.T, genUpdate genSQLFunc) { + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c2 INT, c3 INT, PRIMARY KEY (c, c2))") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT, c2 INT, c3 INT, PRIMARY KEY (c, c2))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c2 INT, c3 INT, PRIMARY KEY (c, c2))") + + change1 := NewRowChange(source1, target, []interface{}{1, 2, 3}, []interface{}{10, 20, 30}, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, []interface{}{4, 5, 6}, []interface{}{40, 50, 60}, sourceTI2, targetTI, nil) + sql, args := genUpdate(change1, change2) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END, " + + "`c2`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END, " + + "`c3`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? END " + + "WHERE ROW(`c`,`c2`) IN (ROW(?,?),ROW(?,?))" + expectedArgs := []interface{}{ + 1, 2, 10, 4, 5, 40, + 1, 2, 20, 4, 5, 50, + 1, 2, 30, 4, 5, 60, + 1, 2, 4, 5, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsOneColPK(t *testing.T, genUpdate genSQLFunc) { + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c2 INT, c3 INT, PRIMARY KEY (c))") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT, c2 INT, c3 INT, PRIMARY KEY (c))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c2 INT, c3 INT, PRIMARY KEY (c))") + + change1 := NewRowChange(source1, target, []interface{}{1, 2, 3}, []interface{}{10, 20, 30}, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, []interface{}{4, 5, 6}, []interface{}{40, 50, 60}, sourceTI2, targetTI, nil) + sql, args := genUpdate(change1, change2) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c2`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c3`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END " + + "WHERE `c` IN (?,?)" + expectedArgs := []interface{}{ + 1, 10, 4, 40, + 1, 20, 4, 50, + 1, 30, 4, 60, + 1, 4, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c1 int as (c+100) virtual not null, c2 INT, c3 INT, PRIMARY KEY (c))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c1 int as (c+100) virtual not null, c2 INT, c3 INT, PRIMARY KEY (c))") + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3}, []interface{}{10, 110, 20, 30}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{4, 104, 5, 6}, []interface{}{40, 140, 50, 60}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{7, 107, 8, 9}, []interface{}{70, 170, 80, 90}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c2`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + + "`c3`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? THEN ? END " + + "WHERE `c` IN (?,?,?)" + expectedArgs := []interface{}{ + 1, 10, 4, 40, 7, 70, + 1, 20, 4, 50, 7, 80, + 1, 30, 4, 60, 7, 90, + 1, 4, 7, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +// multiple generated columns test case +func testGenUpdateMultiRowsWithVirtualGeneratedColumns(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb1 (c0 int as (c4*c4) virtual not null, + c1 int as (c+100) virtual not null, c2 INT, c3 INT, c4 INT, PRIMARY KEY (c4))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c0 int as (c4*c4) virtual not null, + c1 int as (c+100) virtual not null, c2 INT, c3 INT, c4 INT, PRIMARY KEY (c4))`) + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3, 1}, []interface{}{100, 110, 20, 30, 10}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{16, 104, 5, 6, 4}, []interface{}{1600, 140, 50, 60, 40}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{49, 107, 8, 9, 7}, []interface{}{4900, 170, 80, 90, 70}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c2`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END, " + + "`c3`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END, " + + "`c4`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END " + + "WHERE `c4` IN (?,?,?)" + expectedArgs := []interface{}{ + 1, 20, 4, 50, 7, 80, + 1, 30, 4, 60, 7, 90, + 1, 10, 4, 40, 7, 70, + 1, 4, 7, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, "CREATE TABLE tb1 (c INT, c1 int as (c+100) stored, c2 INT, c3 INT, PRIMARY KEY (c1))") + targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT, c1 int as (c+100) stored, c2 INT, c3 INT, PRIMARY KEY (c1))") + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3}, []interface{}{10, 110, 20, 30}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{4, 104, 5, 6}, []interface{}{40, 140, 50, 60}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{7, 107, 8, 9}, []interface{}{70, 170, 80, 90}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? END, " + + "`c2`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? END, " + + "`c3`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? END " + + "WHERE `c1` IN (?,?,?)" + expectedArgs := []interface{}{ + 101, 10, 104, 40, 107, 70, + 101, 20, 104, 50, 107, 80, + 101, 30, 104, 60, 107, 90, + 101, 104, 107, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func TestGenInsertMultiRows(t *testing.T) { + t.Parallel() + + source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} + source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + targetTI := mockTableInfo(t, "CREATE TABLE tb (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") + + change1 := NewRowChange(source1, target, nil, []interface{}{2, 1, 2}, sourceTI1, targetTI, nil) + change2 := NewRowChange(source2, target, nil, []interface{}{4, 3, 4}, sourceTI2, targetTI, nil) + + sql, args := GenInsertSQL(DMLInsert, change1, change2) + require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) + + sql, args = GenInsertSQL(DMLReplace, change1, change2) + require.Equal(t, "REPLACE INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) + + sql, args = GenInsertSQL(DMLInsertOnDuplicateUpdate, change1, change2) + require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", sql) + require.Equal(t, []interface{}{1, 2, 3, 4}, args) +} diff --git a/pkg/sqlmodel/multivalue_test.go b/pkg/sqlmodel/multivalue_test.go deleted file mode 100644 index a06326d5ee9..00000000000 --- a/pkg/sqlmodel/multivalue_test.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlmodel - -import ( - "testing" - - "github.com/stretchr/testify/require" - - cdcmodel "github.com/pingcap/tiflow/cdc/model" -) - -func TestGenDeleteMultiValue(t *testing.T) { - t.Parallel() - - source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} - source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} - target := &cdcmodel.TableName{Schema: "db", Table: "tb"} - - sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (c INT PRIMARY KEY, c2 INT)") - sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (c INT PRIMARY KEY, c2 INT)") - targetTI := mockTableInfo(t, "CREATE TABLE tb (c INT PRIMARY KEY, c2 INT)") - - change1 := NewRowChange(source1, target, []interface{}{1, 2}, nil, sourceTI1, targetTI, nil) - change2 := NewRowChange(source2, target, []interface{}{3, 4}, nil, sourceTI2, targetTI, nil) - sql, args := GenDeleteSQL(change1, change2) - - require.Equal(t, "DELETE FROM `db`.`tb` WHERE (`c`) IN ((?),(?))", sql) - require.Equal(t, []interface{}{1, 3}, args) -} - -func TestGenInsertMultiValue(t *testing.T) { - t.Parallel() - - source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} - source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} - target := &cdcmodel.TableName{Schema: "db", Table: "tb"} - - sourceTI1 := mockTableInfo(t, "CREATE TABLE tb1 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") - sourceTI2 := mockTableInfo(t, "CREATE TABLE tb2 (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") - targetTI := mockTableInfo(t, "CREATE TABLE tb (gen INT AS (c+1), c INT PRIMARY KEY, c2 INT)") - - change1 := NewRowChange(source1, target, nil, []interface{}{2, 1, 2}, sourceTI1, targetTI, nil) - change2 := NewRowChange(source2, target, nil, []interface{}{4, 3, 4}, sourceTI2, targetTI, nil) - - sql, args := GenInsertSQL(DMLInsert, change1, change2) - require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) - require.Equal(t, []interface{}{1, 2, 3, 4}, args) - - sql, args = GenInsertSQL(DMLReplace, change1, change2) - require.Equal(t, "REPLACE INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?)", sql) - require.Equal(t, []interface{}{1, 2, 3, 4}, args) - - sql, args = GenInsertSQL(DMLInsertOnDuplicateUpdate, change1, change2) - require.Equal(t, "INSERT INTO `db`.`tb` (`c`,`c2`) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE `c`=VALUES(`c`),`c2`=VALUES(`c2`)", sql) - require.Equal(t, []interface{}{1, 2, 3, 4}, args) -}