Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: fix the issue that UNIQUE constraint on boolean column results in an incorrect result in a comparison (#17245) #17306

Merged
merged 1 commit into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion executor/batch_checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction,
if err != nil {
return nil, err
}
oldRow[col.Offset], err = table.CastValue(sctx, val, col.ToInfo(), false)
oldRow[col.Offset], err = table.CastValue(sctx, val, col.ToInfo(), false, false)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta
if err != nil {
return errors.Trace(err)
}
val, err = table.CastValue(tblReaderExec.ctx, val, col.ColumnInfo, false)
val, err = table.CastValue(tblReaderExec.ctx, val, col.ColumnInfo, false, false)
if err != nil {
return errors.Trace(err)
}
Expand Down
2 changes: 1 addition & 1 deletion executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1715,7 +1715,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd
}
// Because the expression might return different type from
// the generated column, we should wrap a CAST on the result.
castDatum, err := table.CastValue(sctx, datum, columns[idx], true)
castDatum, err := table.CastValue(sctx, datum, columns[idx], false, true)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func (e *InsertExec) doDupRowUpdate(ctx context.Context, handle int64, oldRow []
if err1 != nil {
return nil, false, 0, err1
}
e.row4Update[col.Col.Index], err1 = table.CastValue(e.ctx, val, col.Col.ToInfo(), false)
e.row4Update[col.Col.Index], err1 = table.CastValue(e.ctx, val, col.Col.ToInfo(), false, false)
if err1 != nil {
return nil, false, 0, err1
}
Expand Down
16 changes: 8 additions & 8 deletions executor/insert_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func (e *InsertValues) evalRow(ctx context.Context, list []expression.Expression
if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil {
return nil, err
}
val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false)
val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false, false)
if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil {
return nil, err
}
Expand Down Expand Up @@ -349,7 +349,7 @@ func (e *InsertValues) fastEvalRow(ctx context.Context, list []expression.Expres
if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil {
return nil, err
}
val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false)
val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false, false)
if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil {
return nil, err
}
Expand Down Expand Up @@ -473,7 +473,7 @@ func (e *InsertValues) getRow(ctx context.Context, vals []types.Datum) ([]types.
row := make([]types.Datum, len(e.Table.Cols()))
hasValue := make([]bool, len(e.Table.Cols()))
for i, v := range vals {
casted, err := table.CastValue(e.ctx, v, e.insertColumns[i].ToInfo(), false)
casted, err := table.CastValue(e.ctx, v, e.insertColumns[i].ToInfo(), false, false)
if e.handleErr(nil, &v, 0, err) != nil {
return nil, err
}
Expand Down Expand Up @@ -576,7 +576,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
if e.handleErr(gCol, &val, 0, err) != nil {
return nil, err
}
row[colIdx], err = table.CastValue(e.ctx, val, gCol.ToInfo(), false)
row[colIdx], err = table.CastValue(e.ctx, val, gCol.ToInfo(), false, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -723,7 +723,7 @@ func (e *InsertValues) lazyAdjustAutoIncrementDatum(ctx context.Context, rows []
retryInfo.AddAutoIncrementID(id)

// The value of d is adjusted by auto ID, so we need to cast it again.
d, err := table.CastValue(e.ctx, d, col.ToInfo(), false)
d, err := table.CastValue(e.ctx, d, col.ToInfo(), false, false)
if err != nil {
return nil, err
}
Expand All @@ -736,7 +736,7 @@ func (e *InsertValues) lazyAdjustAutoIncrementDatum(ctx context.Context, rows []
retryInfo.AddAutoIncrementID(recordID)

// the value of d is adjusted by auto ID, so we need to cast it again.
autoDatum, err = table.CastValue(e.ctx, autoDatum, col.ToInfo(), false)
autoDatum, err = table.CastValue(e.ctx, autoDatum, col.ToInfo(), false, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -796,7 +796,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Dat
retryInfo.AddAutoIncrementID(recordID)

// the value of d is adjusted by auto ID, so we need to cast it again.
casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false)
casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false, false)
if err != nil {
return types.Datum{}, err
}
Expand Down Expand Up @@ -878,7 +878,7 @@ func (e *InsertValues) adjustAutoRandomDatum(ctx context.Context, d types.Datum,
d.SetAutoID(recordID, c.Flag)
retryInfo.AddAutoRandomID(recordID)

casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false)
casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false, false)
if err != nil {
return types.Datum{}, err
}
Expand Down
5 changes: 4 additions & 1 deletion executor/point_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,10 @@ func encodeIndexKey(e *baseExecutor, tblInfo *model.TableInfo, idxInfo *model.In
str, err = idxVals[i].ToString()
idxVals[i].SetString(str, colInfo.FieldType.Collate)
} else {
idxVals[i], err = table.CastValue(e.ctx, idxVals[i], colInfo, false)
idxVals[i], err = table.CastValue(e.ctx, idxVals[i], colInfo, true, false)
if types.ErrOverflow.Equal(err) {
return nil, kv.ErrNotExist
}
}
if err != nil {
return nil, err
Expand Down
13 changes: 13 additions & 0 deletions executor/point_get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,19 @@ func (s *testPointGetSuite) TestPointGet(c *C) {
"<nil> <nil>"))
}

func (s *testPointGetSuite) TestPointGetOverflow(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t0")
tk.MustExec("CREATE TABLE t0(c1 BOOL UNIQUE)")
tk.MustExec("INSERT INTO t0(c1) VALUES (-128)")
tk.MustExec("INSERT INTO t0(c1) VALUES (127)")
tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=-129").Check(testkit.Rows()) // no result
tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=-128").Check(testkit.Rows("-128"))
tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=128").Check(testkit.Rows())
tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=127").Check(testkit.Rows("127"))
}

func (s *testPointGetSuite) TestPointGetCharPK(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec(`use test;`)
Expand Down
2 changes: 1 addition & 1 deletion executor/union_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error {
}
// Because the expression might return different type from
// the generated column, we should wrap a CAST on the result.
castDatum, err := table.CastValue(us.ctx, datum, us.columns[idx], true)
castDatum, err := table.CastValue(us.ctx, datum, us.columns[idx], false, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions executor/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (e *UpdateExec) fastComposeNewRow(rowIdx int, oldRow []types.Datum, cols []
// info of `_tidb_rowid` column is nil.
// No need to cast `_tidb_rowid` column value.
if cols[assign.Col.Index] != nil {
val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false)
val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false, false)
if err = e.handleErr(assign.ColName, rowIdx, err); err != nil {
return nil, err
}
Expand All @@ -241,7 +241,7 @@ func (e *UpdateExec) composeNewRow(rowIdx int, oldRow []types.Datum, cols []*tab
// info of `_tidb_rowid` column is nil.
// No need to cast `_tidb_rowid` column value.
if cols[assign.Col.Index] != nil {
val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false)
val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false, false)
if err = e.handleErr(assign.ColName, rowIdx, err); err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h int64, oldData
for i, col := range t.Cols() {
if modified[i] {
// Cast changed fields with respective columns.
v, err := table.CastValue(sctx, newData[i], col.ToInfo(), false)
v, err := table.CastValue(sctx, newData[i], col.ToInfo(), false, false)
if err != nil {
return false, false, 0, err
}
Expand Down
16 changes: 10 additions & 6 deletions table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func CastValues(ctx sessionctx.Context, rec []types.Datum, cols []*Column) (err
sc := ctx.GetSessionVars().StmtCtx
for _, c := range cols {
var converted types.Datum
converted, err = CastValue(ctx, rec[c.Offset], c.ToInfo(), false)
converted, err = CastValue(ctx, rec[c.Offset], c.ToInfo(), false, false)
if err != nil {
if sc.DupKeyAsWarning {
sc.AppendWarning(err)
Expand All @@ -168,15 +168,19 @@ func handleWrongUtf8Value(ctx sessionctx.Context, col *model.ColumnInfo, casted
}

// CastValue casts a value based on column type.
// If forceIgnoreTruncate is true, the err returned will be always nil.
// If forceIgnoreTruncate is true, truncated errors will be ignored.
// If returnOverflow is true, don't handle overflow errors in this function.
// It's safe now and it's the same as the behavior of select statement.
// Set it to true only in FillVirtualColumnValue and UnionScanExec.Next()
// If the handle of err is changed latter, the behavior of forceIgnoreTruncate also need to change.
// TODO: change the third arg to TypeField. Not pass ColumnInfo.
func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, forceIgnoreTruncate bool) (casted types.Datum, err error) {
func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnOverflow, forceIgnoreTruncate bool) (casted types.Datum, err error) {
sc := ctx.GetSessionVars().StmtCtx
casted, err = val.ConvertTo(sc, &col.FieldType)
// TODO: make sure all truncate errors are handled by ConvertTo.
if types.ErrOverflow.Equal(err) && returnOverflow {
return casted, err
}
if types.ErrTruncated.Equal(err) {
str, err1 := val.ToString()
if err1 != nil {
Expand Down Expand Up @@ -398,7 +402,7 @@ func EvalColDefaultExpr(ctx sessionctx.Context, col *model.ColumnInfo, defaultEx
return types.Datum{}, err
}
// Check the evaluated data type by cast.
value, err := CastValue(ctx, d, col, false)
value, err := CastValue(ctx, d, col, false, false)
if err != nil {
return types.Datum{}, err
}
Expand All @@ -417,7 +421,7 @@ func getColDefaultExprValue(ctx sessionctx.Context, col *model.ColumnInfo, defau
return types.Datum{}, err
}
// Check the evaluated data type by cast.
value, err := CastValue(ctx, d, col, false)
value, err := CastValue(ctx, d, col, false, false)
if err != nil {
return types.Datum{}, err
}
Expand All @@ -430,7 +434,7 @@ func getColDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo, defaultVa
}

if col.Tp != mysql.TypeTimestamp && col.Tp != mysql.TypeDatetime {
value, err := CastValue(ctx, types.NewDatum(defaultVal), col, false)
value, err := CastValue(ctx, types.NewDatum(defaultVal), col, false, false)
if err != nil {
return types.Datum{}, err
}
Expand Down
10 changes: 5 additions & 5 deletions table/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,11 @@ func (t *testTableSuite) TestCastValue(c *C) {
State: model.StatePublic,
}
colInfo.Charset = mysql.UTF8Charset
val, err := CastValue(ctx, types.Datum{}, &colInfo, false)
val, err := CastValue(ctx, types.Datum{}, &colInfo, false, false)
c.Assert(err, Equals, nil)
c.Assert(val.GetInt64(), Equals, int64(0))

val, err = CastValue(ctx, types.NewDatum("test"), &colInfo, false)
val, err = CastValue(ctx, types.NewDatum("test"), &colInfo, false, false)
c.Assert(err, Not(Equals), nil)
c.Assert(val.GetInt64(), Equals, int64(0))

Expand All @@ -278,16 +278,16 @@ func (t *testTableSuite) TestCastValue(c *C) {
FieldType: *types.NewFieldType(mysql.TypeString),
State: model.StatePublic,
}
val, err = CastValue(ctx, types.NewDatum("test"), &colInfoS, false)
val, err = CastValue(ctx, types.NewDatum("test"), &colInfoS, false, false)
c.Assert(err, IsNil)
c.Assert(val, NotNil)

colInfoS.Charset = mysql.UTF8Charset
_, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x8c, 0x80}), &colInfoS, false)
_, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x8c, 0x80}), &colInfoS, false, false)
c.Assert(err, NotNil)

colInfoS.Charset = mysql.UTF8MB4Charset
_, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x80}), &colInfoS, false)
_, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x80}), &colInfoS, false, false)
c.Assert(err, NotNil)
}

Expand Down
2 changes: 1 addition & 1 deletion util/rowDecoder/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, handle int
if err != nil {
return nil, err
}
val, err = table.CastValue(ctx, val, col.Col.ColumnInfo, false)
val, err = table.CastValue(ctx, val, col.Col.ColumnInfo, false, false)
if err != nil {
return nil, err
}
Expand Down