diff --git a/pkg/sqlmodel/multirow.go b/pkg/sqlmodel/multirow.go index 0f7341bf749..452a69bb2f9 100644 --- a/pkg/sqlmodel/multirow.go +++ b/pkg/sqlmodel/multirow.go @@ -16,12 +16,12 @@ package sqlmodel import ( "strings" + "github.com/pingcap/log" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/format" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/opcode" driver "github.com/pingcap/tidb/types/parser_driver" - "github.com/pingcap/tiflow/dm/pkg/log" "github.com/pingcap/tiflow/pkg/quotes" "go.uber.org/zap" ) @@ -122,6 +122,152 @@ func GenDeleteSQL(changes ...*RowChange) (string, []interface{}) { return buf.String(), args } +// GenUpdateSQLFast generates the UPDATE SQL and its arguments. +// Input `changes` should have same target table and same columns for WHERE +// (typically same PK/NOT NULL UK), otherwise the behaviour is undefined. +// It is a faster version compared with GenUpdateSQL. +func GenUpdateSQLFast(changes ...*RowChange) (string, []any) { + if len(changes) == 0 { + log.L().DPanic("row changes is empty") + return "", nil + } + var buf strings.Builder + buf.Grow(1024) + + // Generate UPDATE `db`.`table` SET + first := changes[0] + buf.WriteString("UPDATE ") + buf.WriteString(first.targetTable.QuoteString()) + buf.WriteString(" SET ") + + // Pre-generate essential sub statements used after WHEN, WHERE and IN. + var ( + whereCaseStmt string + whenCaseStmt string + inCaseStmt string + ) + whereColumns, _ := first.whereColumnsAndValues() + if len(whereColumns) == 1 { + // one field PK or UK, use `field`=? directly. + whereCaseStmt = quotes.QuoteName(whereColumns[0]) + whenCaseStmt = whereCaseStmt + "=?" + inCaseStmt = valuesHolder(len(changes)) + } else { + // multiple fields PK or UK, use ROW(...fields) expression. 
+ whereValuesHolder := valuesHolder(len(whereColumns)) + whereCaseStmt = "ROW(" + for i, column := range whereColumns { + whereCaseStmt += quotes.QuoteName(column) + if i != len(whereColumns)-1 { + whereCaseStmt += "," + } else { + whereCaseStmt += ")" + whenCaseStmt = whereCaseStmt + "=ROW" + whereValuesHolder + } + } + var inCaseStmtBuf strings.Builder + // inCaseStmt sample: IN (ROW(?,?,?),ROW(?,?,?)) + // ^ ^ + // Buffer size counted between |---------------------| + // equals 3 * len(changes) for each `ROW` + // plus 1 * len(changes) - 1 for each `,` between every two ROW(?,?,?) + // plus len(whereValuesHolder) * len(changes) + // plus 2 for `(` and `)` + inCaseStmtBuf.Grow((4+len(whereValuesHolder))*len(changes) + 1) + inCaseStmtBuf.WriteString("(") + for i := range changes { + inCaseStmtBuf.WriteString("ROW") + inCaseStmtBuf.WriteString(whereValuesHolder) + if i != len(changes)-1 { + inCaseStmtBuf.WriteString(",") + } else { + inCaseStmtBuf.WriteString(")") + } + } + inCaseStmt = inCaseStmtBuf.String() + } + + // Generate `ColumnName`=CASE WHEN .. THEN .. END + // Use this flag to identify the first CaseWhenThen line that is written, + // because generated columns can appear anywhere and are skipped. + isFirstCaseWhenThenLine := true + for _, column := range first.targetTableInfo.Columns { + if isGenerated(first.targetTableInfo.Columns, column.Name) { + continue + } + if !isFirstCaseWhenThenLine { + // write ", " before every assignment except the first one. + buf.WriteString(", ") + } + + buf.WriteString(quotes.QuoteName(column.Name.String()) + "=CASE") + for range changes { + buf.WriteString(" WHEN ") + buf.WriteString(whenCaseStmt) + buf.WriteString(" THEN ?") + } + buf.WriteString(" END") + isFirstCaseWhenThenLine = false + } + + // Generate WHERE .. IN .. 
+ buf.WriteString(" WHERE ") + buf.WriteString(whereCaseStmt) + buf.WriteString(" IN ") + buf.WriteString(inCaseStmt) + + // Build args of the UPDATE SQL + var assignValueColumnCount int + var skipColIdx []int + for i, col := range first.sourceTableInfo.Columns { + if isGenerated(first.targetTableInfo.Columns, col.Name) { + skipColIdx = append(skipColIdx, i) + continue + } + assignValueColumnCount++ + } + args := make([]any, 0, + assignValueColumnCount*len(changes)*(len(whereColumns)+1)+len(changes)*len(whereColumns)) + argsPerCol := make([][]any, assignValueColumnCount) + for i := 0; i < assignValueColumnCount; i++ { + argsPerCol[i] = make([]any, 0, len(changes)*(len(whereColumns)+1)) + } + whereValuesAtTheEnd := make([]any, 0, len(changes)*len(whereColumns)) + for _, change := range changes { + _, whereValues := change.whereColumnsAndValues() + // a simple check about different number of WHERE values, not trying to + // cover all cases + if len(whereValues) != len(whereColumns) { + log.Panic("len(whereValues) != len(whereColumns)", + zap.Int("len(whereValues)", len(whereValues)), + zap.Int("len(whereColumns)", len(whereColumns)), + zap.Any("whereValues", whereValues), + zap.Stringer("sourceTable", change.sourceTable)) + return "", nil + } + + whereValuesAtTheEnd = append(whereValuesAtTheEnd, whereValues...) + + i := 0 // used as index of skipColIdx + writeableCol := 0 + for j, val := range change.postValues { + if i < len(skipColIdx) && skipColIdx[i] == j { + i++ + continue + } + argsPerCol[writeableCol] = append(argsPerCol[writeableCol], whereValues...) + argsPerCol[writeableCol] = append(argsPerCol[writeableCol], val) + writeableCol++ + } + } + for _, a := range argsPerCol { + args = append(args, a...) + } + args = append(args, whereValuesAtTheEnd...) + + return buf.String(), args +} + // GenUpdateSQL generates the UPDATE SQL and its arguments. 
// Input `changes` should have same target table and same columns for WHERE // (typically same PK/NOT NULL UK), otherwise the behaviour is undefined. @@ -224,7 +370,7 @@ func GenUpdateSQL(changes ...*RowChange) (string, []interface{}) { // a simple check about different number of WHERE values, not trying to // cover all cases if len(whereValues) != len(whereColumns) { - log.L().DPanic("len(whereValues) != len(whereColumns)", + log.Panic("len(whereValues) != len(whereColumns)", zap.Int("len(whereValues)", len(whereValues)), zap.Int("len(whereColumns)", len(whereColumns)), zap.Any("whereValues", whereValues), diff --git a/pkg/sqlmodel/multirow_bench_test.go b/pkg/sqlmodel/multirow_bench_test.go new file mode 100644 index 00000000000..496758274b3 --- /dev/null +++ b/pkg/sqlmodel/multirow_bench_test.go @@ -0,0 +1,110 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sqlmodel + +import ( + "fmt" + "testing" + "time" + + cdcmodel "github.com/pingcap/tiflow/cdc/model" +) + +func prepareDataOneColoumnPK(t *testing.T, batch int) []*RowChange { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb (c INT, c2 INT, c3 INT, + c4 VARCHAR(10), c5 VARCHAR(100), c6 VARCHAR(1000), PRIMARY KEY (c))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c INT, c2 INT, c3 INT, + c4 VARCHAR(10), c5 VARCHAR(100), c6 VARCHAR(1000), PRIMARY KEY (c))`) + + changes := make([]*RowChange, 0, batch) + for i := 0; i < batch; i++ { + change := NewRowChange(source, target, + []interface{}{i + 1, i + 2, i + 3, "c4", "c5", "c6"}, + []interface{}{i + 10, i + 20, i + 30, "c4", "c5", "c6"}, + sourceTI, targetTI, nil) + changes = append(changes, change) + } + return changes +} + +func prepareDataMultiColumnsPK(t *testing.T, batch int) []*RowChange { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb (c1 INT, c2 INT, c3 INT, c4 INT, + c5 VARCHAR(10), c6 VARCHAR(100), c7 VARCHAR(1000), c8 timestamp, c9 timestamp, + PRIMARY KEY (c1, c2, c3, c4))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c1 INT, c2 INT, c3 INT, c4 INT, + c5 VARCHAR(10), c6 VARCHAR(100), c7 VARCHAR(1000), c8 timestamp, c9 timestamp, + PRIMARY KEY (c1, c2, c3, c4))`) + + changes := make([]*RowChange, 0, batch) + for i := 0; i < batch; i++ { + change := NewRowChange(source, target, + []interface{}{i + 1, i + 2, i + 3, i + 4, "c5", "c6", "c7", time.Time{}, time.Time{}}, // exactly one value per column c1..c9 + []interface{}{i + 10, i + 20, i + 30, i + 40, "c5", "c6", "c7", time.Time{}, time.Time{}}, + sourceTI, targetTI, nil) + changes = append(changes, change) + } + return changes +} + +// bench cmd: go test -run='^$' -benchmem -bench '^(BenchmarkGenUpdate)$' 
github.com/pingcap/tiflow/pkg/sqlmodel +func BenchmarkGenUpdate(b *testing.B) { + t := &testing.T{} + type genCase struct { + name string + fn genSQLFunc + prepare func(t *testing.T, batch int) []*RowChange + } + batchSizes := []int{ + 1, 2, 4, 8, 16, 32, 64, 128, + } + benchCases := []genCase{ + { + name: "OneColumnPK-GenUpdateSQL", + fn: GenUpdateSQL, + prepare: prepareDataOneColoumnPK, + }, + { + name: "OneColumnPK-GenUpdateSQLFast", + fn: GenUpdateSQLFast, + prepare: prepareDataOneColoumnPK, + }, + { + name: "MultiColumnsPK-GenUpdateSQL", + fn: GenUpdateSQL, + prepare: prepareDataMultiColumnsPK, + }, + { + name: "MultiColumnsPK-GenUpdateSQLFast", + fn: GenUpdateSQLFast, + prepare: prepareDataMultiColumnsPK, + }, + } + for _, bc := range benchCases { + for _, batch := range batchSizes { + name := fmt.Sprintf("%s-Batch%d", bc.name, batch) + b.Run(name, func(b *testing.B) { + changes := bc.prepare(t, batch) // use this case's own data builder so MultiColumnsPK cases get multi-column rows + for i := 0; i < b.N; i++ { + bc.fn(changes...) + } + }) + } + } +} diff --git a/pkg/sqlmodel/multirow_test.go b/pkg/sqlmodel/multirow_test.go index 793b6d44056..3fd84ec9a2e 100644 --- a/pkg/sqlmodel/multirow_test.go +++ b/pkg/sqlmodel/multirow_test.go @@ -20,6 +20,8 @@ import ( "github.com/stretchr/testify/require" ) +type genSQLFunc func(changes ...*RowChange) (string, []interface{}) + func TestGenDeleteMultiRows(t *testing.T) { t.Parallel() @@ -41,7 +43,31 @@ func TestGenDeleteMultiRows(t *testing.T) { func TestGenUpdateMultiRows(t *testing.T) { t.Parallel() + testGenUpdateMultiRows(t, GenUpdateSQL) + testGenUpdateMultiRows(t, GenUpdateSQLFast) +} + +func TestGenUpdateMultiRowsOneColPK(t *testing.T) { + t.Parallel() + testGenUpdateMultiRowsOneColPK(t, GenUpdateSQL) + testGenUpdateMultiRowsOneColPK(t, GenUpdateSQLFast) +} + +func TestGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T) { + t.Parallel() + testGenUpdateMultiRowsWithVirtualGeneratedColumn(t, GenUpdateSQL) + testGenUpdateMultiRowsWithVirtualGeneratedColumn(t, GenUpdateSQLFast) 
+ testGenUpdateMultiRowsWithVirtualGeneratedColumns(t, GenUpdateSQL) + testGenUpdateMultiRowsWithVirtualGeneratedColumns(t, GenUpdateSQLFast) +} +func TestGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T) { + t.Parallel() + testGenUpdateMultiRowsWithStoredGeneratedColumn(t, GenUpdateSQL) + testGenUpdateMultiRowsWithStoredGeneratedColumn(t, GenUpdateSQLFast) +} + +func testGenUpdateMultiRows(t *testing.T, genUpdate genSQLFunc) { source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} target := &cdcmodel.TableName{Schema: "db", Table: "tb"} @@ -52,7 +78,7 @@ func TestGenUpdateMultiRows(t *testing.T) { change1 := NewRowChange(source1, target, []interface{}{1, 2, 3}, []interface{}{10, 20, 30}, sourceTI1, targetTI, nil) change2 := NewRowChange(source2, target, []interface{}{4, 5, 6}, []interface{}{40, 50, 60}, sourceTI2, targetTI, nil) - sql, args := GenUpdateSQL(change1, change2) + sql, args := genUpdate(change1, change2) expectedSQL := "UPDATE `db`.`tb` SET " + "`c`=CASE WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? WHEN ROW(`c`,`c2`)=ROW(?,?) THEN ? 
END, " + @@ -70,9 +96,7 @@ func TestGenUpdateMultiRows(t *testing.T) { require.Equal(t, expectedArgs, args) } -func TestGenUpdateMultiRowsOneColPK(t *testing.T) { - t.Parallel() - +func testGenUpdateMultiRowsOneColPK(t *testing.T, genUpdate genSQLFunc) { source1 := &cdcmodel.TableName{Schema: "db", Table: "tb1"} source2 := &cdcmodel.TableName{Schema: "db", Table: "tb2"} target := &cdcmodel.TableName{Schema: "db", Table: "tb"} @@ -83,7 +107,7 @@ func TestGenUpdateMultiRowsOneColPK(t *testing.T) { change1 := NewRowChange(source1, target, []interface{}{1, 2, 3}, []interface{}{10, 20, 30}, sourceTI1, targetTI, nil) change2 := NewRowChange(source2, target, []interface{}{4, 5, 6}, []interface{}{40, 50, 60}, sourceTI2, targetTI, nil) - sql, args := GenUpdateSQL(change1, change2) + sql, args := genUpdate(change1, change2) expectedSQL := "UPDATE `db`.`tb` SET " + "`c`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? END, " + @@ -101,8 +125,7 @@ func TestGenUpdateMultiRowsOneColPK(t *testing.T) { require.Equal(t, expectedArgs, args) } -func TestGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T) { - t.Parallel() +func testGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T, genUpdate genSQLFunc) { source := &cdcmodel.TableName{Schema: "db", Table: "tb"} target := &cdcmodel.TableName{Schema: "db", Table: "tb"} @@ -112,7 +135,7 @@ func TestGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T) { change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3}, []interface{}{10, 110, 20, 30}, sourceTI, targetTI, nil) change2 := NewRowChange(source, target, []interface{}{4, 104, 5, 6}, []interface{}{40, 140, 50, 60}, sourceTI, targetTI, nil) change3 := NewRowChange(source, target, []interface{}{7, 107, 8, 9}, []interface{}{70, 170, 80, 90}, sourceTI, targetTI, nil) - sql, args := GenUpdateSQL(change1, change2, change3) + sql, args := genUpdate(change1, change2, change3) expectedSQL := "UPDATE `db`.`tb` SET " + "`c`=CASE WHEN `c`=? THEN ? WHEN `c`=? THEN ? WHEN `c`=? 
THEN ? END, " + @@ -130,8 +153,38 @@ func TestGenUpdateMultiRowsWithVirtualGeneratedColumn(t *testing.T) { require.Equal(t, expectedArgs, args) } -func TestGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T) { - t.Parallel() +// multiple generated columns test case +func testGenUpdateMultiRowsWithVirtualGeneratedColumns(t *testing.T, genUpdate genSQLFunc) { + source := &cdcmodel.TableName{Schema: "db", Table: "tb"} + target := &cdcmodel.TableName{Schema: "db", Table: "tb"} + + sourceTI := mockTableInfo(t, `CREATE TABLE tb1 (c0 int as (c4*c4) virtual not null, + c1 int as (c+100) virtual not null, c2 INT, c3 INT, c4 INT, PRIMARY KEY (c4))`) + targetTI := mockTableInfo(t, `CREATE TABLE tb (c0 int as (c4*c4) virtual not null, + c1 int as (c+100) virtual not null, c2 INT, c3 INT, c4 INT, PRIMARY KEY (c4))`) + + change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3, 1}, []interface{}{100, 110, 20, 30, 10}, sourceTI, targetTI, nil) + change2 := NewRowChange(source, target, []interface{}{16, 104, 5, 6, 4}, []interface{}{1600, 140, 50, 60, 40}, sourceTI, targetTI, nil) + change3 := NewRowChange(source, target, []interface{}{49, 107, 8, 9, 7}, []interface{}{4900, 170, 80, 90, 70}, sourceTI, targetTI, nil) + sql, args := genUpdate(change1, change2, change3) + + expectedSQL := "UPDATE `db`.`tb` SET " + + "`c2`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END, " + + "`c3`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? END, " + + "`c4`=CASE WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? WHEN `c4`=? THEN ? 
END " + + "WHERE `c4` IN (?,?,?)" + expectedArgs := []interface{}{ + 1, 20, 4, 50, 7, 80, + 1, 30, 4, 60, 7, 90, + 1, 10, 4, 40, 7, 70, + 1, 4, 7, + } + + require.Equal(t, expectedSQL, sql) + require.Equal(t, expectedArgs, args) +} + +func testGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T, genUpdate genSQLFunc) { source := &cdcmodel.TableName{Schema: "db", Table: "tb"} target := &cdcmodel.TableName{Schema: "db", Table: "tb"} @@ -141,7 +194,7 @@ func TestGenUpdateMultiRowsWithStoredGeneratedColumn(t *testing.T) { change1 := NewRowChange(source, target, []interface{}{1, 101, 2, 3}, []interface{}{10, 110, 20, 30}, sourceTI, targetTI, nil) change2 := NewRowChange(source, target, []interface{}{4, 104, 5, 6}, []interface{}{40, 140, 50, 60}, sourceTI, targetTI, nil) change3 := NewRowChange(source, target, []interface{}{7, 107, 8, 9}, []interface{}{70, 170, 80, 90}, sourceTI, targetTI, nil) - sql, args := GenUpdateSQL(change1, change2, change3) + sql, args := genUpdate(change1, change2, change3) expectedSQL := "UPDATE `db`.`tb` SET " + "`c`=CASE WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? WHEN `c1`=? THEN ? END, " +