From 67a98f8dcb7d1d90396efecae4a9adf50f34cb97 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Wed, 20 May 2020 13:55:18 +0800 Subject: [PATCH] executor: fix the issue that UNIQUE constraint on boolean column results in an incorrect result in a comparison (#17245) --- executor/batch_checker.go | 2 +- executor/distsql.go | 2 +- executor/executor.go | 2 +- executor/insert.go | 2 +- executor/insert_common.go | 16 ++++++++-------- executor/point_get.go | 5 ++++- executor/point_get_test.go | 13 +++++++++++++ executor/union_scan.go | 2 +- executor/update.go | 4 ++-- executor/write.go | 2 +- table/column.go | 16 ++++++++++------ table/column_test.go | 10 +++++----- util/rowDecoder/decoder.go | 2 +- 13 files changed, 49 insertions(+), 29 deletions(-) diff --git a/executor/batch_checker.go b/executor/batch_checker.go index 15b23829d3027..69cf451cfc5a7 100644 --- a/executor/batch_checker.go +++ b/executor/batch_checker.go @@ -198,7 +198,7 @@ func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, if err != nil { return nil, err } - oldRow[col.Offset], err = table.CastValue(sctx, val, col.ToInfo(), false) + oldRow[col.Offset], err = table.CastValue(sctx, val, col.ToInfo(), false, false) if err != nil { return nil, err } diff --git a/executor/distsql.go b/executor/distsql.go index 3a72094af8b7c..69996dca4eebf 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -863,7 +863,7 @@ func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, ta if err != nil { return errors.Trace(err) } - val, err = table.CastValue(tblReaderExec.ctx, val, col.ColumnInfo, false) + val, err = table.CastValue(tblReaderExec.ctx, val, col.ColumnInfo, false, false) if err != nil { return errors.Trace(err) } diff --git a/executor/executor.go b/executor/executor.go index dc184da341fcf..b602f6cf9c120 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1715,7 +1715,7 @@ func FillVirtualColumnValue(virtualRetTypes []*types.FieldType, virtualColumnInd } // Because the expression might return different type from // the generated column, we should wrap a CAST on the result. - castDatum, err := table.CastValue(sctx, datum, columns[idx], true) + castDatum, err := table.CastValue(sctx, datum, columns[idx], false, true) if err != nil { return err } diff --git a/executor/insert.go b/executor/insert.go index 7df70c11a57e0..c00d2dc82088a 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -333,7 +333,7 @@ func (e *InsertExec) doDupRowUpdate(ctx context.Context, handle int64, oldRow [] if err1 != nil { return nil, false, 0, err1 } - e.row4Update[col.Col.Index], err1 = table.CastValue(e.ctx, val, col.Col.ToInfo(), false) + e.row4Update[col.Col.Index], err1 = table.CastValue(e.ctx, val, col.Col.ToInfo(), false, false) if err1 != nil { return nil, false, 0, err1 } diff --git a/executor/insert_common.go b/executor/insert_common.go index bfdba11fb2098..a33c319b8bb00 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -320,7 +320,7 @@ func (e *InsertValues) evalRow(ctx context.Context, list []expression.Expression if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil { return nil, err } - val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false) + val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false, false) if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil { return nil, err } @@ -349,7 +349,7 @@ func (e *InsertValues) fastEvalRow(ctx context.Context, list []expression.Expres if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil { return nil, err } - val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false) + val1, err := table.CastValue(e.ctx, val, e.insertColumns[i].ToInfo(), false, false) if err = e.handleErr(e.insertColumns[i], &val, rowIdx, err); err != nil { return nil, err } @@ -473,7 +473,7 @@ func (e *InsertValues) getRow(ctx context.Context, vals []types.Datum) ([]types. row := make([]types.Datum, len(e.Table.Cols())) hasValue := make([]bool, len(e.Table.Cols())) for i, v := range vals { - casted, err := table.CastValue(e.ctx, v, e.insertColumns[i].ToInfo(), false) + casted, err := table.CastValue(e.ctx, v, e.insertColumns[i].ToInfo(), false, false) if e.handleErr(nil, &v, 0, err) != nil { return nil, err } @@ -576,7 +576,7 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue if e.handleErr(gCol, &val, 0, err) != nil { return nil, err } - row[colIdx], err = table.CastValue(e.ctx, val, gCol.ToInfo(), false) + row[colIdx], err = table.CastValue(e.ctx, val, gCol.ToInfo(), false, false) if err != nil { return nil, err } @@ -723,7 +723,7 @@ func (e *InsertValues) lazyAdjustAutoIncrementDatum(ctx context.Context, rows [] retryInfo.AddAutoIncrementID(id) // The value of d is adjusted by auto ID, so we need to cast it again. - d, err := table.CastValue(e.ctx, d, col.ToInfo(), false) + d, err := table.CastValue(e.ctx, d, col.ToInfo(), false, false) if err != nil { return nil, err } @@ -736,7 +736,7 @@ func (e *InsertValues) lazyAdjustAutoIncrementDatum(ctx context.Context, rows [] retryInfo.AddAutoIncrementID(recordID) // the value of d is adjusted by auto ID, so we need to cast it again. - autoDatum, err = table.CastValue(e.ctx, autoDatum, col.ToInfo(), false) + autoDatum, err = table.CastValue(e.ctx, autoDatum, col.ToInfo(), false, false) if err != nil { return nil, err } @@ -796,7 +796,7 @@ func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Dat retryInfo.AddAutoIncrementID(recordID) // the value of d is adjusted by auto ID, so we need to cast it again. - casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false) + casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false, false) if err != nil { return types.Datum{}, err } @@ -878,7 +878,7 @@ func (e *InsertValues) adjustAutoRandomDatum(ctx context.Context, d types.Datum, d.SetAutoID(recordID, c.Flag) retryInfo.AddAutoRandomID(recordID) - casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false) + casted, err := table.CastValue(e.ctx, d, c.ToInfo(), false, false) if err != nil { return types.Datum{}, err } diff --git a/executor/point_get.go b/executor/point_get.go index 13957aef77d87..aa0e6a2b37552 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -310,7 +310,10 @@ func encodeIndexKey(e *baseExecutor, tblInfo *model.TableInfo, idxInfo *model.In str, err = idxVals[i].ToString() idxVals[i].SetString(str, colInfo.FieldType.Collate) } else { - idxVals[i], err = table.CastValue(e.ctx, idxVals[i], colInfo, false) + idxVals[i], err = table.CastValue(e.ctx, idxVals[i], colInfo, true, false) + if types.ErrOverflow.Equal(err) { + return nil, kv.ErrNotExist + } } if err != nil { return nil, err diff --git a/executor/point_get_test.go b/executor/point_get_test.go index 89a09d9eade62..24ef8513298f4 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -112,6 +112,19 @@ func (s *testPointGetSuite) TestPointGet(c *C) { " ")) } +func (s *testPointGetSuite) TestPointGetOverflow(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t0") + tk.MustExec("CREATE TABLE t0(c1 BOOL UNIQUE)") + tk.MustExec("INSERT INTO t0(c1) VALUES (-128)") + tk.MustExec("INSERT INTO t0(c1) VALUES (127)") + tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=-129").Check(testkit.Rows()) // no result + tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=-128").Check(testkit.Rows("-128")) + tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=128").Check(testkit.Rows()) + tk.MustQuery("SELECT t0.c1 FROM t0 WHERE t0.c1=127").Check(testkit.Rows("127")) +} + func (s *testPointGetSuite) TestPointGetCharPK(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test;`) diff --git a/executor/union_scan.go b/executor/union_scan.go index d4e922bb1e9ad..e2e744c4b6aff 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -174,7 +174,7 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { } // Because the expression might return different type from // the generated column, we should wrap a CAST on the result. - castDatum, err := table.CastValue(us.ctx, datum, us.columns[idx], true) + castDatum, err := table.CastValue(us.ctx, datum, us.columns[idx], false, true) if err != nil { return err } diff --git a/executor/update.go b/executor/update.go index 565fdf8b78e5d..e5ebe0d783a36 100644 --- a/executor/update.go +++ b/executor/update.go @@ -214,7 +214,7 @@ func (e *UpdateExec) fastComposeNewRow(rowIdx int, oldRow []types.Datum, cols [] // info of `_tidb_rowid` column is nil. // No need to cast `_tidb_rowid` column value. if cols[assign.Col.Index] != nil { - val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false) + val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false, false) if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { return nil, err } @@ -241,7 +241,7 @@ func (e *UpdateExec) composeNewRow(rowIdx int, oldRow []types.Datum, cols []*tab // info of `_tidb_rowid` column is nil. // No need to cast `_tidb_rowid` column value. if cols[assign.Col.Index] != nil { - val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false) + val, err = table.CastValue(e.ctx, val, cols[assign.Col.Index].ColumnInfo, false, false) if err = e.handleErr(assign.ColName, rowIdx, err); err != nil { return nil, err } diff --git a/executor/write.go b/executor/write.go index 890538e1f25e8..33bdff8afcbb0 100644 --- a/executor/write.go +++ b/executor/write.go @@ -76,7 +76,7 @@ func updateRecord(ctx context.Context, sctx sessionctx.Context, h int64, oldData for i, col := range t.Cols() { if modified[i] { // Cast changed fields with respective columns. - v, err := table.CastValue(sctx, newData[i], col.ToInfo(), false) + v, err := table.CastValue(sctx, newData[i], col.ToInfo(), false, false) if err != nil { return false, false, 0, err } diff --git a/table/column.go b/table/column.go index 307e4c36ef0cd..909f673e7512d 100644 --- a/table/column.go +++ b/table/column.go @@ -143,7 +143,7 @@ func CastValues(ctx sessionctx.Context, rec []types.Datum, cols []*Column) (err sc := ctx.GetSessionVars().StmtCtx for _, c := range cols { var converted types.Datum - converted, err = CastValue(ctx, rec[c.Offset], c.ToInfo(), false) + converted, err = CastValue(ctx, rec[c.Offset], c.ToInfo(), false, false) if err != nil { if sc.DupKeyAsWarning { sc.AppendWarning(err) @@ -168,15 +168,19 @@ func handleWrongUtf8Value(ctx sessionctx.Context, col *model.ColumnInfo, casted } // CastValue casts a value based on column type. -// If forceIgnoreTruncate is true, the err returned will be always nil. +// If forceIgnoreTruncate is true, truncated errors will be ignored. +// If returnOverflow is true, don't handle overflow errors in this function. // It's safe now and it's the same as the behavior of select statement. // Set it to true only in FillVirtualColumnValue and UnionScanExec.Next() // If the handle of err is changed latter, the behavior of forceIgnoreTruncate also need to change. // TODO: change the third arg to TypeField. Not pass ColumnInfo. -func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, forceIgnoreTruncate bool) (casted types.Datum, err error) { +func CastValue(ctx sessionctx.Context, val types.Datum, col *model.ColumnInfo, returnOverflow, forceIgnoreTruncate bool) (casted types.Datum, err error) { sc := ctx.GetSessionVars().StmtCtx casted, err = val.ConvertTo(sc, &col.FieldType) // TODO: make sure all truncate errors are handled by ConvertTo. + if types.ErrOverflow.Equal(err) && returnOverflow { + return casted, err + } if types.ErrTruncated.Equal(err) { str, err1 := val.ToString() if err1 != nil { @@ -398,7 +402,7 @@ func EvalColDefaultExpr(ctx sessionctx.Context, col *model.ColumnInfo, defaultEx return types.Datum{}, err } // Check the evaluated data type by cast. - value, err := CastValue(ctx, d, col, false) + value, err := CastValue(ctx, d, col, false, false) if err != nil { return types.Datum{}, err } @@ -417,7 +421,7 @@ func getColDefaultExprValue(ctx sessionctx.Context, col *model.ColumnInfo, defau return types.Datum{}, err } // Check the evaluated data type by cast. - value, err := CastValue(ctx, d, col, false) + value, err := CastValue(ctx, d, col, false, false) if err != nil { return types.Datum{}, err } @@ -430,7 +434,7 @@ func getColDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo, defaultVa } if col.Tp != mysql.TypeTimestamp && col.Tp != mysql.TypeDatetime { - value, err := CastValue(ctx, types.NewDatum(defaultVal), col, false) + value, err := CastValue(ctx, types.NewDatum(defaultVal), col, false, false) if err != nil { return types.Datum{}, err } diff --git a/table/column_test.go b/table/column_test.go index 904d12f86fff8..feffb801ebf6a 100644 --- a/table/column_test.go +++ b/table/column_test.go @@ -254,11 +254,11 @@ func (t *testTableSuite) TestCastValue(c *C) { State: model.StatePublic, } colInfo.Charset = mysql.UTF8Charset - val, err := CastValue(ctx, types.Datum{}, &colInfo, false) + val, err := CastValue(ctx, types.Datum{}, &colInfo, false, false) c.Assert(err, Equals, nil) c.Assert(val.GetInt64(), Equals, int64(0)) - val, err = CastValue(ctx, types.NewDatum("test"), &colInfo, false) + val, err = CastValue(ctx, types.NewDatum("test"), &colInfo, false, false) c.Assert(err, Not(Equals), nil) c.Assert(val.GetInt64(), Equals, int64(0)) @@ -278,16 +278,16 @@ func (t *testTableSuite) TestCastValue(c *C) { FieldType: *types.NewFieldType(mysql.TypeString), State: model.StatePublic, } - val, err = CastValue(ctx, types.NewDatum("test"), &colInfoS, false) + val, err = CastValue(ctx, types.NewDatum("test"), &colInfoS, false, false) c.Assert(err, IsNil) c.Assert(val, NotNil) colInfoS.Charset = mysql.UTF8Charset - _, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x8c, 0x80}), &colInfoS, false) + _, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x8c, 0x80}), &colInfoS, false, false) c.Assert(err, NotNil) colInfoS.Charset = mysql.UTF8MB4Charset - _, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x80}), &colInfoS, false) + _, err = CastValue(ctx, types.NewDatum([]byte{0xf0, 0x9f, 0x80}), &colInfoS, false, false) c.Assert(err, NotNil) } diff --git a/util/rowDecoder/decoder.go b/util/rowDecoder/decoder.go index 73bd318d78922..3550cf78ef0df 100644 --- a/util/rowDecoder/decoder.go +++ b/util/rowDecoder/decoder.go @@ -123,7 +123,7 @@ func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, handle int if err != nil { return nil, err } - val, err = table.CastValue(ctx, val, col.Col.ColumnInfo, false) + val, err = table.CastValue(ctx, val, col.Col.ColumnInfo, false, false) if err != nil { return nil, err }