diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index d0f38e035..f7cac28f3 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -19,7 +19,6 @@ import ( "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/codec" ) const implicitColID = -1 @@ -598,64 +597,3 @@ func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { return data, nil } - -// DecodeOldAndNewRow decodes a byte slice into datums with a existing row map. -// Row layout: colID1, value1, colID2, value2, ..... -func DecodeOldAndNewRow(b []byte, cols map[int64]*types.FieldType, loc *time.Location) (map[int64]types.Datum, map[int64]types.Datum, error) { - if b == nil { - return nil, nil, nil - } - if b[0] == codec.NilFlag { - return nil, nil, nil - } - - cnt := 0 - var ( - data []byte - err error - oldRow = make(map[int64]types.Datum, len(cols)) - newRow = make(map[int64]types.Datum, len(cols)) - ) - for len(b) > 0 { - // Get col id. - data, b, err = codec.CutOne(b) - if err != nil { - return nil, nil, errors.Trace(err) - } - _, cid, err := codec.DecodeOne(data) - if err != nil { - return nil, nil, errors.Trace(err) - } - // Get col value. - data, b, err = codec.CutOne(b) - if err != nil { - return nil, nil, errors.Trace(err) - } - id := cid.GetInt64() - ft, ok := cols[id] - if ok { - v, err := tablecodec.DecodeColumnValue(data, ft, loc) - if err != nil { - return nil, nil, errors.Trace(err) - } - - if _, ok := oldRow[id]; ok { - newRow[id] = v - } else { - oldRow[id] = v - } - - cnt++ - if cnt == len(cols)*2 { - // Get enough data. 
- break - } - } - } - - if cnt != len(cols)*2 || len(newRow) != len(oldRow) { - return nil, nil, errors.Errorf(" row data is corruption %v", b) - } - - return oldRow, newRow, nil -} diff --git a/drainer/translator/sequence_iterator.go b/drainer/translator/sequence_iterator.go new file mode 100644 index 000000000..d0d5aa686 --- /dev/null +++ b/drainer/translator/sequence_iterator.go @@ -0,0 +1,48 @@ +package translator + +import ( + "io" + + "github.com/pingcap/errors" + "github.com/pingcap/tipb/go-binlog" +) + +// sequenceIterator is a helper to iterate row event by sequence +type sequenceIterator struct { + mutation *binlog.TableMutation + idx int + insertIdx int + deleteIdx int + updateIdx int +} + +func newSequenceIterator(mutation *binlog.TableMutation) *sequenceIterator { + return &sequenceIterator{mutation: mutation} +} + +func (si *sequenceIterator) next() (tp binlog.MutationType, row []byte, err error) { + if si.idx >= len(si.mutation.Sequence) { + err = io.EOF + return + } + + tp = si.mutation.Sequence[si.idx] + si.idx++ + + switch tp { + case binlog.MutationType_Insert: + row = si.mutation.InsertedRows[si.insertIdx] + si.insertIdx++ + case binlog.MutationType_Update: + row = si.mutation.UpdatedRows[si.updateIdx] + si.updateIdx++ + case binlog.MutationType_DeleteRow: + row = si.mutation.DeletedRows[si.deleteIdx] + si.deleteIdx++ + default: + err = errors.Errorf("unknown mutation type: %v", tp) + return + } + + return +} diff --git a/drainer/translator/sequence_iterator_test.go b/drainer/translator/sequence_iterator_test.go new file mode 100644 index 000000000..8d48169f2 --- /dev/null +++ b/drainer/translator/sequence_iterator_test.go @@ -0,0 +1,61 @@ +package translator + +import ( + "io" + + . 
"github.com/pingcap/check" + ti "github.com/pingcap/tipb/go-binlog" +) + +type testSequenceIteratorSuite struct{} + +var _ = Suite(&testSequenceIteratorSuite{}) + +func (t *testSequenceIteratorSuite) TestIterator(c *C) { + mut := new(ti.TableMutation) + var tps []ti.MutationType + var rows [][]byte + + // generate test data + for i := 0; i < 10; i++ { + row := []byte{byte(i)} + rows = append(rows, row) + switch i % 3 { + case 0: + mut.Sequence = append(mut.Sequence, ti.MutationType_Insert) + mut.InsertedRows = append(mut.InsertedRows, row) + tps = append(tps, ti.MutationType_Insert) + case 1: + mut.Sequence = append(mut.Sequence, ti.MutationType_Update) + mut.UpdatedRows = append(mut.UpdatedRows, row) + tps = append(tps, ti.MutationType_Update) + case 2: + mut.Sequence = append(mut.Sequence, ti.MutationType_DeleteRow) + mut.DeletedRows = append(mut.DeletedRows, row) + tps = append(tps, ti.MutationType_DeleteRow) + } + } + + // get back by iterator + iter := newSequenceIterator(mut) + var getTps []ti.MutationType + var getRows [][]byte + + for { + tp, row, err := iter.next() + if err == io.EOF { + break + } + if err != nil { + c.Assert(err, IsNil) + c.Fail() + break + } + + getTps = append(getTps, tp) + getRows = append(getRows, row) + } + + c.Assert(getTps, DeepEquals, tps) + c.Assert(getRows, DeepEquals, rows) +} diff --git a/drainer/translator/table_info.go b/drainer/translator/table_info.go new file mode 100644 index 000000000..d7cfcd786 --- /dev/null +++ b/drainer/translator/table_info.go @@ -0,0 +1,9 @@ +package translator + +import "github.com/pingcap/parser/model" + +// TableInfoGetter is used to get table info by table id of TiDB +type TableInfoGetter interface { + TableByID(id int64) (info *model.TableInfo, ok bool) + SchemaAndTableName(id int64) (string, string, bool) +} diff --git a/drainer/translator/testing.go b/drainer/translator/testing.go new file mode 100644 index 000000000..3f47d2ca6 --- /dev/null +++ b/drainer/translator/testing.go @@ -0,0 +1,367 @@ +package 
translator + +import ( + "fmt" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/tablecodec" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/codec" + ti "github.com/pingcap/tipb/go-binlog" +) + +var _ TableInfoGetter = &BinlogGenrator{} + +// BinlogGenrator is a test helper for generating some binlog. +type BinlogGenrator struct { + TiBinlog *ti.Binlog + PV *ti.PrewriteValue + Schema string + Table string + + id2info map[int64]*model.TableInfo + id2name map[int64][2]string + + datums []types.Datum + oldDatums []types.Datum +} + +func (g *BinlogGenrator) reset() { + g.TiBinlog = nil + g.PV = nil + g.Schema = "" + g.Table = "" + g.id2info = make(map[int64]*model.TableInfo) + g.id2name = make(map[int64][2]string) +} + +// SetDelete set the info to be a delete event +func (g *BinlogGenrator) SetDelete(c *C) { + g.reset() + info := g.setEvent(c) + + row := testGenDeleteBinlog(c, info, g.datums) + + g.PV.Mutations = append(g.PV.Mutations, ti.TableMutation{ + TableId: info.ID, + DeletedRows: [][]byte{row}, + Sequence: []ti.MutationType{ti.MutationType_DeleteRow}, + }) +} + +func (g *BinlogGenrator) getDatums() (datums []types.Datum) { + datums = g.datums + return +} + +func (g *BinlogGenrator) getOldDatums() (datums []types.Datum) { + datums = g.oldDatums + return +} + +// TableByID implements TableInfoGetter interface +func (g *BinlogGenrator) TableByID(id int64) (info *model.TableInfo, ok bool) { + info, ok = g.id2info[id] + return +} + +// SchemaAndTableName implements TableInfoGetter interface +func (g *BinlogGenrator) SchemaAndTableName(id int64) (schema string, table string, ok bool) { + names, ok := g.id2name[id] + if !ok { + return "", "", false + } + + schema = names[0] + table = names[1] + ok = true + return +} + +// SetDDL set up a ddl binlog. 
+func (g *BinlogGenrator) SetDDL() { + g.reset() + g.TiBinlog = &ti.Binlog{ + Tp: ti.BinlogType_Commit, + StartTs: 100, + CommitTs: 200, + DdlQuery: []byte("create table test(id int)"), + DdlJobId: 1, + } + g.PV = nil + g.Schema = "test" + g.Table = "test" +} + +func (g *BinlogGenrator) setEvent(c *C) *model.TableInfo { + g.TiBinlog = &ti.Binlog{ + Tp: ti.BinlogType_Commit, + StartTs: 100, + CommitTs: 200, + } + + g.PV = new(ti.PrewriteValue) + + info := testGenTable("hasID") + g.id2info[info.ID] = info + g.id2name[info.ID] = [2]string{"test", info.Name.L} + + g.datums = testGenRandomDatums(c, info.Columns) + g.oldDatums = testGenRandomDatums(c, info.Columns) + c.Assert(len(g.datums), Equals, len(info.Columns)) + + return info +} + +// SetInsert set up a insert event binlog. +func (g *BinlogGenrator) SetInsert(c *C) { + g.reset() + info := g.setEvent(c) + + row := testGenInsertBinlog(c, info, g.datums) + g.PV.Mutations = append(g.PV.Mutations, ti.TableMutation{ + TableId: info.ID, + InsertedRows: [][]byte{row}, + Sequence: []ti.MutationType{ti.MutationType_Insert}, + }) +} + +// SetUpdate set up a update event binlog. 
+func (g *BinlogGenrator) SetUpdate(c *C) { + g.reset() + info := g.setEvent(c) + + row := testGenUpdateBinlog(c, info, g.oldDatums, g.datums) + + g.PV.Mutations = append(g.PV.Mutations, ti.TableMutation{ + TableId: info.ID, + UpdatedRows: [][]byte{row}, + Sequence: []ti.MutationType{ti.MutationType_Update}, + }) +} + +// hasID: create table t(id int primary key, name varchar(45), sex enum("male", "female")); +// hasPK: create table t(id int, name varchar(45), sex enum("male", "female"), PRIMARY KEY(id, name)); +// normal: create table t(id int, name varchar(45), sex enum("male", "female")); +func testGenTable(tt string) *model.TableInfo { + t := &model.TableInfo{State: model.StatePublic} + t.Name = model.NewCIStr("account") + + // the hard values are from TiDB :-), so just ingore them + userIDCol := &model.ColumnInfo{ + ID: 1, + Name: model.NewCIStr("ID"), + Offset: 0, + FieldType: types.FieldType{ + Tp: mysql.TypeLong, + Flag: mysql.BinaryFlag, + Flen: 11, + Decimal: -1, + Charset: "binary", + Collate: "binary", + }, + State: model.StatePublic, + } + + userNameCol := &model.ColumnInfo{ + ID: 2, + Name: model.NewCIStr("NAME"), + Offset: 1, + FieldType: types.FieldType{ + Tp: mysql.TypeVarchar, + Flag: 0, + Flen: 45, + Decimal: -1, + Charset: "utf8", + Collate: "utf8_unicode_ci", + }, + State: model.StatePublic, + } + + sexCol := &model.ColumnInfo{ + ID: 3, + Name: model.NewCIStr("SEX"), + Offset: 2, + FieldType: types.FieldType{ + Tp: mysql.TypeEnum, + Flag: mysql.BinaryFlag, + Flen: -1, + Decimal: -1, + Charset: "binary", + Collate: "binary", + Elems: []string{"male", "female"}, + }, + State: model.StatePublic, + } + + t.Columns = []*model.ColumnInfo{userIDCol, userNameCol, sexCol} + + switch tt { + case "hasID": + userIDCol.Flag = mysql.NotNullFlag | mysql.PriKeyFlag | mysql.BinaryFlag | mysql.NoDefaultValueFlag + + t.PKIsHandle = true + t.Indices = append(t.Indices, &model.IndexInfo{ + Primary: true, + Columns: []*model.IndexColumn{{Name: userIDCol.Name}}, + }) 
+ + case "hasPK": + userIDCol.Flag = mysql.NotNullFlag | mysql.PriKeyFlag | mysql.BinaryFlag | mysql.NoDefaultValueFlag | mysql.UniqueKeyFlag + userNameCol.Flag = mysql.NotNullFlag | mysql.PriKeyFlag | mysql.NoDefaultValueFlag + + t.Indices = append(t.Indices, &model.IndexInfo{ + Primary: true, + Unique: true, + Columns: []*model.IndexColumn{{Name: userIDCol.Name}, {Name: userNameCol.Name}}, + }) + } + + return t +} + +func testGenRandomDatums(c *C, cols []*model.ColumnInfo) (datums []types.Datum) { + for i := 0; i < len(cols); i++ { + datum, _ := testGenDatum(c, cols[i], i) + datums = append(datums, datum) + } + + return +} + +func testGenDeleteBinlog(c *C, t *model.TableInfo, r []types.Datum) []byte { + var data []byte + var err error + + sc := &stmtctx.StatementContext{TimeZone: time.Local} + colIDs := make([]int64, len(t.Columns)) + for i, col := range t.Columns { + colIDs[i] = col.ID + } + data, err = tablecodec.EncodeRow(sc, r, colIDs, nil, nil) + c.Assert(err, IsNil) + return data +} + +// generate raw row data by column.Type +func testGenDatum(c *C, col *model.ColumnInfo, base int) (types.Datum, interface{}) { + var d types.Datum + var e interface{} + switch col.Tp { + case mysql.TypeTiny, mysql.TypeInt24, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: + if mysql.HasUnsignedFlag(col.Flag) { + d.SetUint64(uint64(base)) + e = int64(base) + } else { + d.SetInt64(int64(base)) + e = int64(base) + } + case mysql.TypeFloat: + d.SetFloat32(float32(base)) + e = float32(base) + case mysql.TypeDouble: + d.SetFloat64(float64(base)) + e = float64(base) + case mysql.TypeNewDecimal: + d.SetMysqlDecimal(types.NewDecFromInt(int64(base))) + e = fmt.Sprintf("%v", base) + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: + baseVal := "test" + val := "" + for i := 0; i < base; i++ { + val = fmt.Sprintf("%s%s", val, baseVal) + } + d.SetString(val) + e = []byte(val) + case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, 
mysql.TypeLongBlob: + baseVal := "test" + val := "" + for i := 0; i < base; i++ { + val = fmt.Sprintf("%s%s", val, baseVal) + } + d.SetBytes([]byte(val)) + e = []byte(val) + case mysql.TypeDuration: + duration, err := types.ParseDuration(new(stmtctx.StatementContext), "10:10:10", 0) + c.Assert(err, IsNil) + d.SetMysqlDuration(duration) + e = "10:10:10" + case mysql.TypeDate, mysql.TypeNewDate: + t := types.CurrentTime(mysql.TypeDate) + d.SetMysqlTime(t) + e = t.String() + case mysql.TypeTimestamp: + t := types.CurrentTime(mysql.TypeTimestamp) + d.SetMysqlTime(t) + e = t.String() + case mysql.TypeDatetime: + t := types.CurrentTime(mysql.TypeDatetime) + d.SetMysqlTime(t) + e = t.String() + case mysql.TypeBit: + bit, err := types.ParseBitStr("0b01") + c.Assert(err, IsNil) + d.SetMysqlBit(bit) + case mysql.TypeSet: + elems := []string{"a", "b", "c", "d"} + set, err := types.ParseSetName(elems, elems[base-1]) + c.Assert(err, IsNil) + d.SetMysqlSet(set) + e = set.Value + case mysql.TypeEnum: + elems := []string{"male", "female"} + enum, err := types.ParseEnumName(elems, elems[base-1]) + c.Assert(err, IsNil) + d.SetMysqlEnum(enum) + e = enum.Value + } + return d, e +} + +func testGenInsertBinlog(c *C, t *model.TableInfo, r []types.Datum) []byte { + sc := &stmtctx.StatementContext{TimeZone: time.Local} + var recordID int64 = 11 + + colIDs := make([]int64, 0, len(r)) + row := make([]types.Datum, 0, len(r)) + for idx, col := range t.Columns { + if testIsPKHandleColumn(t, col) { + recordID = r[idx].GetInt64() + continue + } + + colIDs = append(colIDs, col.ID) + row = append(row, r[idx]) + } + + value, err := tablecodec.EncodeRow(sc, row, colIDs, nil, nil) + c.Assert(err, IsNil) + + handleVal, _ := codec.EncodeValue(sc, nil, types.NewIntDatum(recordID)) + bin := append(handleVal, value...) 
+ return bin +} + +func testGenUpdateBinlog(c *C, t *model.TableInfo, oldData []types.Datum, newData []types.Datum) []byte { + sc := &stmtctx.StatementContext{TimeZone: time.Local} + colIDs := make([]int64, 0, len(t.Columns)) + for _, col := range t.Columns { + colIDs = append(colIDs, col.ID) + } + + var bin []byte + value, err := tablecodec.EncodeRow(sc, newData, colIDs, nil, nil) + c.Assert(err, IsNil) + oldValue, err := tablecodec.EncodeRow(sc, oldData, colIDs, nil, nil) + c.Assert(err, IsNil) + bin = append(oldValue, value...) + return bin +} + +func testIsPKHandleColumn(table *model.TableInfo, column *model.ColumnInfo) bool { + return mysql.HasPriKeyFlag(column.Flag) && table.PKIsHandle +} diff --git a/drainer/translator/translator.go b/drainer/translator/translator.go index efa07c5a3..69c0a9288 100644 --- a/drainer/translator/translator.go +++ b/drainer/translator/translator.go @@ -128,3 +128,64 @@ func getDefaultOrZeroValue(col *model.ColumnInfo) types.Datum { return table.GetZeroValue(col) } + +// DecodeOldAndNewRow decodes a byte slice into datums with a existing row map. +// Row layout: colID1, value1, colID2, value2, ..... +func DecodeOldAndNewRow(b []byte, cols map[int64]*types.FieldType, loc *time.Location) (map[int64]types.Datum, map[int64]types.Datum, error) { + if b == nil { + return nil, nil, nil + } + if b[0] == codec.NilFlag { + return nil, nil, nil + } + + cnt := 0 + var ( + data []byte + err error + oldRow = make(map[int64]types.Datum, len(cols)) + newRow = make(map[int64]types.Datum, len(cols)) + ) + for len(b) > 0 { + // Get col id. + data, b, err = codec.CutOne(b) + if err != nil { + return nil, nil, errors.Trace(err) + } + _, cid, err := codec.DecodeOne(data) + if err != nil { + return nil, nil, errors.Trace(err) + } + // Get col value. 
+ data, b, err = codec.CutOne(b) + if err != nil { + return nil, nil, errors.Trace(err) + } + id := cid.GetInt64() + ft, ok := cols[id] + if ok { + v, err := tablecodec.DecodeColumnValue(data, ft, loc) + if err != nil { + return nil, nil, errors.Trace(err) + } + + if _, ok := oldRow[id]; ok { + newRow[id] = v + } else { + oldRow[id] = v + } + + cnt++ + if cnt == len(cols)*2 { + // Get enough data. + break + } + } + } + + if cnt != len(cols)*2 || len(newRow) != len(oldRow) { + return nil, nil, errors.Errorf("row data is corrupted %v", b) + } + + return oldRow, newRow, nil +} diff --git a/drainer/translator/translator_test.go b/drainer/translator/translator_test.go index ec9645f4c..dbdb4d257 100644 --- a/drainer/translator/translator_test.go +++ b/drainer/translator/translator_test.go @@ -5,16 +5,10 @@ import ( "strings" "testing" - "time" - . "github.com/pingcap/check" "github.com/pingcap/parser/model" parsermysql "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/sessionctx/stmtctx" - "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/codec" ) func TestClient(t *testing.T) { @@ -157,64 +151,6 @@ func testGenDDLSQL(c *C, s SQLTranslator) { c.Assert(sql, Equals, "use `t`; drop table t;") } -func testGenInsertBinlog(c *C, t *model.TableInfo, r []types.Datum) []byte { - sc := &stmtctx.StatementContext{TimeZone: time.Local} - recordID := int64(11) - for _, col := range t.Columns { - if testIsPKHandleColumn(t, col) { - recordID = r[col.Offset].GetInt64() - break - } - } - - colIDs := make([]int64, 0, len(r)) - row := make([]types.Datum, 0, len(r)) - for _, col := range t.Columns { - if testIsPKHandleColumn(t, col) { - continue - } - colIDs = append(colIDs, col.ID) - row = append(row, r[col.Offset]) - } - - value, err := tablecodec.EncodeRow(sc, row, colIDs, nil, nil) - c.Assert(err, IsNil) - - handleVal, _ := codec.EncodeValue(sc, nil, types.NewIntDatum(recordID)) - bin := 
append(handleVal, value...) - return bin -} - -func testGenUpdateBinlog(c *C, t *model.TableInfo, oldData []types.Datum, newData []types.Datum) []byte { - sc := &stmtctx.StatementContext{TimeZone: time.Local} - colIDs := make([]int64, 0, len(t.Columns)) - for _, col := range t.Columns { - colIDs = append(colIDs, col.ID) - } - - var bin []byte - value, err := tablecodec.EncodeRow(sc, newData, colIDs, nil, nil) - c.Assert(err, IsNil) - oldValue, err := tablecodec.EncodeRow(sc, oldData, colIDs, nil, nil) - c.Assert(err, IsNil) - bin = append(oldValue, value...) - return bin -} - -func testGenDeleteBinlog(c *C, t *model.TableInfo, r []types.Datum) []byte { - var data []byte - var err error - - sc := &stmtctx.StatementContext{TimeZone: time.Local} - colIDs := make([]int64, len(t.Columns)) - for i, col := range t.Columns { - colIDs[i] = col.ID - } - data, err = tablecodec.EncodeRow(sc, r, colIDs, nil, nil) - c.Assert(err, IsNil) - return data -} - func testGenRowData(c *C, cols []*model.ColumnInfo, base int) ([]types.Datum, []interface{}, []string) { datas := make([]types.Datum, len(cols)) excepted := make([]interface{}, len(cols)) @@ -228,162 +164,3 @@ func testGenRowData(c *C, cols []*model.ColumnInfo, base int) ([]types.Datum, [] } return datas, excepted, keys } - -// generate raw row data by column.Type -func testGenDatum(c *C, col *model.ColumnInfo, base int) (types.Datum, interface{}) { - var d types.Datum - var e interface{} - switch col.Tp { - case mysql.TypeTiny, mysql.TypeInt24, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: - if mysql.HasUnsignedFlag(col.Flag) { - d.SetUint64(uint64(base)) - e = int64(base) - } else { - d.SetInt64(int64(base)) - e = int64(base) - } - case mysql.TypeFloat: - d.SetFloat32(float32(base)) - e = float32(base) - case mysql.TypeDouble: - d.SetFloat64(float64(base)) - e = float64(base) - case mysql.TypeNewDecimal: - d.SetMysqlDecimal(types.NewDecFromInt(int64(base))) - e = fmt.Sprintf("%v", base) - case 
mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar: - baseVal := "test" - val := "" - for i := 0; i < base; i++ { - val = fmt.Sprintf("%s%s", val, baseVal) - } - d.SetString(val) - e = []byte(val) - case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - baseVal := "test" - val := "" - for i := 0; i < base; i++ { - val = fmt.Sprintf("%s%s", val, baseVal) - } - d.SetBytes([]byte(val)) - e = []byte(val) - case mysql.TypeDuration: - duration, err := types.ParseDuration(new(stmtctx.StatementContext), "10:10:10", 0) - c.Assert(err, IsNil) - d.SetMysqlDuration(duration) - e = "10:10:10" - case mysql.TypeDate, mysql.TypeNewDate: - t := types.CurrentTime(mysql.TypeDate) - d.SetMysqlTime(t) - e = t.String() - case mysql.TypeTimestamp: - t := types.CurrentTime(mysql.TypeTimestamp) - d.SetMysqlTime(t) - e = t.String() - case mysql.TypeDatetime: - t := types.CurrentTime(mysql.TypeDatetime) - d.SetMysqlTime(t) - e = t.String() - case mysql.TypeBit: - bit, err := types.ParseBitStr("0b01") - c.Assert(err, IsNil) - d.SetMysqlBit(bit) - case mysql.TypeSet: - elems := []string{"a", "b", "c", "d"} - set, err := types.ParseSetName(elems, elems[base-1]) - c.Assert(err, IsNil) - d.SetMysqlSet(set) - e = set.Value - case mysql.TypeEnum: - elems := []string{"male", "female"} - enum, err := types.ParseEnumName(elems, elems[base-1]) - c.Assert(err, IsNil) - d.SetMysqlEnum(enum) - e = enum.Value - } - return d, e -} - -// hasID: create table t(id int primary key, name varchar(45), sex enum("male", "female")); -// hasPK: create table t(id int, name varchar(45), sex enum("male", "female"), PRIMARY KEY(id, name)); -// normal: create table t(id int, name varchar(45), sex enum("male", "female")); -func testGenTable(tt string) *model.TableInfo { - t := &model.TableInfo{State: model.StatePublic} - t.Name = model.NewCIStr("account") - - // the hard values are from TiDB :-), so just ingore them - userIDCol := &model.ColumnInfo{ - ID: 1, - Name: model.NewCIStr("ID"), 
- Offset: 0, - FieldType: types.FieldType{ - Tp: mysql.TypeLong, - Flag: mysql.BinaryFlag, - Flen: 11, - Decimal: -1, - Charset: "binary", - Collate: "binary", - }, - State: model.StatePublic, - } - - userNameCol := &model.ColumnInfo{ - ID: 2, - Name: model.NewCIStr("NAME"), - Offset: 1, - FieldType: types.FieldType{ - Tp: mysql.TypeVarchar, - Flag: 0, - Flen: 45, - Decimal: -1, - Charset: "utf8", - Collate: "utf8_unicode_ci", - }, - State: model.StatePublic, - } - - sexCol := &model.ColumnInfo{ - ID: 3, - Name: model.NewCIStr("SEX"), - Offset: 2, - FieldType: types.FieldType{ - Tp: mysql.TypeEnum, - Flag: mysql.BinaryFlag, - Flen: -1, - Decimal: -1, - Charset: "binary", - Collate: "binary", - Elems: []string{"male", "female"}, - }, - State: model.StatePublic, - } - - t.Columns = []*model.ColumnInfo{userIDCol, userNameCol, sexCol} - - switch tt { - case "hasID": - userIDCol.Flag = mysql.NotNullFlag | mysql.PriKeyFlag | mysql.BinaryFlag | mysql.NoDefaultValueFlag - - t.PKIsHandle = true - t.Indices = append(t.Indices, &model.IndexInfo{ - Primary: true, - Columns: []*model.IndexColumn{{Name: userIDCol.Name}}, - }) - - case "hasPK": - userIDCol.Flag = mysql.NotNullFlag | mysql.PriKeyFlag | mysql.BinaryFlag | mysql.NoDefaultValueFlag | mysql.UniqueKeyFlag - userNameCol.Flag = mysql.NotNullFlag | mysql.PriKeyFlag | mysql.NoDefaultValueFlag - - t.Indices = append(t.Indices, &model.IndexInfo{ - Primary: true, - Unique: true, - Columns: []*model.IndexColumn{{Name: userIDCol.Name}, {Name: userNameCol.Name}}, - }) - } - - return t -} - -func testIsPKHandleColumn(table *model.TableInfo, column *model.ColumnInfo) bool { - return mysql.HasPriKeyFlag(column.Flag) && table.PKIsHandle -} diff --git a/pkg/loader/load.go b/pkg/loader/load.go index f344ecb60..09d8aa11b 100644 --- a/pkg/loader/load.go +++ b/pkg/loader/load.go @@ -117,7 +117,7 @@ func NewLoader(db *gosql.DB, opt ...Option) (*Loader, error) { } func (s *Loader) metricsInputTxn(txn *Txn) { - if s.metrics == nil { + if 
s.metrics == nil || s.metrics.EventCounterVec == nil { return } @@ -337,7 +337,7 @@ func (s *Loader) execDMLs(dmls []*DML) error { errg, _ := errgroup.WithContext(context.Background()) executor := newExecutor(s.db).withBatchSize(s.batchSize) - if s.metrics != nil { + if s.metrics != nil && s.metrics.QueryHistogramVec != nil { executor = executor.withQueryHistogramVec(s.metrics.QueryHistogramVec) } diff --git a/pkg/loader/util.go b/pkg/loader/util.go index ea6707109..970bc04d5 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -4,6 +4,7 @@ import ( gosql "database/sql" "fmt" "hash/crc32" + "net/url" "strings" "github.com/pingcap/errors" @@ -133,9 +134,13 @@ WHERE table_schema = ? AND table_name = ?;` return } -// CreateDB return sql.DB -func CreateDB(user string, password string, host string, port int) (db *gosql.DB, err error) { +// CreateDBWithSQLMode return sql.DB +func CreateDBWithSQLMode(user string, password string, host string, port int, sqlMode *string) (db *gosql.DB, err error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/?charset=utf8mb4,utf8&interpolateParams=true&readTimeout=1m&multiStatements=true", user, password, host, port) + if sqlMode != nil { + // same as "set sql_mode = ''" + dsn += "&sql_mode='" + url.QueryEscape(*sqlMode) + "'" + } db, err = gosql.Open("mysql", dsn) if err != nil { @@ -144,6 +149,11 @@ func CreateDB(user string, password string, host string, port int) (db *gosql.DB return } +// CreateDB return sql.DB +func CreateDB(user string, password string, host string, port int) (db *gosql.DB, err error) { + return CreateDBWithSQLMode(user, password, host, port, nil) +} + func quoteSchema(schema string, table string) string { return fmt.Sprintf("`%s`.`%s`", escapeName(schema), escapeName(table)) }