diff --git a/expression/builtin.go b/expression/builtin.go
index 3625602eb8fea..05a978ae8a307 100644
--- a/expression/builtin.go
+++ b/expression/builtin.go
@@ -27,6 +27,7 @@ import (
 	"sort"
 	"strings"
 	"sync"
+	"unicode/utf8"
 
 	"github.com/gogo/protobuf/proto"
 	"github.com/pingcap/errors"
@@ -117,16 +118,77 @@ func CheckIllegalMixCollation(funcName string, args []Expression, evalType types
 	if len(args) < 2 {
 		return nil
 	}
-	_, _, coercibility, legal := inferCollation(args...)
+	_, dstCharset, coercibility, legal := inferCollation(args...)
 	if !legal {
 		return illegalMixCollationErr(funcName, args)
 	}
 	if coercibility == CoercibilityNone && evalType != types.ETString {
 		return illegalMixCollationErr(funcName, args)
 	}
+
+	// For every constant argument of string type, MySQL tries to convert its charset to the inferred charset
+	// and returns illegalMixCollationErr if the conversion fails.
+	// e.g.
+	// select _utf8 'a' collate utf8_general_ci = '😁'; # error
+	// select _utf8 'a' collate utf8_general_ci = 'ㅂ'; # fine
+	// MySQL rejects the first statement because '😁' cannot be converted to utf8mb3. TiDB accepts both for now,
+	// and that behavior is kept because we treat utf8mb3 the same as utf8mb4 except on insert, which keeps
+	// things simple.
+	// Since we only have utf8mb4 and its subset charsets for now, checking the validity of the data is enough;
+	// there is no need to convert it.
+	// In the future we may support more charsets, but the data will still be utf8mb4-encoded at runtime, so
+	// the conversion will still be unnecessary.
+	for _, arg := range args {
+		_, ok := arg.(*Constant)
+		if ok && arg.GetType().EvalType() == types.ETString {
+			val, isNull, err := arg.EvalString(nil, chunk.Row{})
+			if err != nil {
+				return err
+			}
+			if isNull {
+				// A NULL constant has no bytes to validate. Keep checking the remaining arguments instead of
+				// returning early, which would leave them unchecked.
+				continue
+			}
+
+			if !isValidString(val, arg.GetType().Charset, dstCharset) {
+				return illegalMixCollationErr(funcName, args)
+			}
+		}
+	}
+
 	return nil
 }
 
+// isValidString reports whether val, whose charset is srcChs, is valid for the charset dstChs.
+// Identical charsets are always accepted. For an ascii destination every rune must be below 0x80; a latin1
+// destination is not checked yet; for any other destination only binary input is checked, and it must be
+// valid UTF-8.
+// TODO: Remove this function when the full charset is supported, we will have a better way to do this job.
+func isValidString(val string, srcChs, dstChs string) bool {
+	if srcChs == dstChs {
+		return true
+	}
+	switch dstChs {
+	case charset.CharsetASCII:
+		for _, c := range val {
+			if c >= 0x80 {
+				return false
+			}
+		}
+	case charset.CharsetLatin1:
+		// TODO: For latin1, we have no convenient way to check it for now, so just let it pass.
+		// We will fix it after the full charset is supported.
+	default:
+		// The charsets in TiDB are all utf8mb4 or one of its subsets except binary, so checking the binary
+		// charset is enough.
+		if srcChs == charset.CharsetBin {
+			return utf8.ValidString(val)
+		}
+	}
+	return true
+}
+
 func illegalMixCollationErr(funcName string, args []Expression) error {
 	funcName = GetDisplayName(funcName)
 
diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go
index 9d11d8b5ad18a..4746ddc7256d3 100644
--- a/expression/builtin_compare_test.go
+++ b/expression/builtin_compare_test.go
@@ -332,11 +332,11 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
 		},
 	} {
 		f0, err := newFunctionForTest(s.ctx, ast.Greatest, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		d, err := f0.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			d, err := f0.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull)
@@ -346,11 +346,11 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) {
 		}
 
 		f1, err := newFunctionForTest(s.ctx, ast.Least, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		d, err = f1.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			d, err := f1.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull)
diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go
index 7f6e35aaa8626..a382fc6f510d0 100644
--- a/expression/builtin_control_test.go
+++ b/expression/builtin_control_test.go
@@ -125,11 +125,11 @@ func (s *testEvaluatorSuite) TestIfNull(c *C) {
 
 	for _, t := range tbl {
 		f, err := newFunctionForTest(s.ctx, ast.Ifnull, s.primitiveValsToConstants([]interface{}{t.arg1, t.arg2})...)
-		c.Assert(err, IsNil)
-		d, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			d, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull)
diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go
index 185cfa0d90b9e..ed2305b87c09a 100644
--- a/expression/builtin_string_test.go
+++ b/expression/builtin_string_test.go
@@ -264,12 +264,12 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) {
 	c.Assert(err, NotNil)
 
 	for _, t := range cases {
-		f, err := newFunctionForTest(s.ctx, fcName, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		val, err1 := f.Eval(chunk.Row{})
+		f, err1 := newFunctionForTest(s.ctx, fcName, s.primitiveValsToConstants(t.args)...)
 		if t.getErr {
 			c.Assert(err1, NotNil)
 		} else {
+			c.Assert(err1, IsNil)
+			val, err1 := f.Eval(chunk.Row{})
 			c.Assert(err1, IsNil)
 			if t.isNil {
 				c.Assert(val.Kind(), Equals, types.KindNull)
@@ -361,11 +361,11 @@ func (s *testEvaluatorSuite) TestLeft(c *C) {
 	}
 	for _, t := range cases {
 		f, err := newFunctionForTest(s.ctx, ast.Left, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		v, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			v, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(v.Kind(), Equals, types.KindNull)
@@ -410,11 +410,11 @@ func (s *testEvaluatorSuite) TestRight(c *C) {
 	}
 	for _, t := range cases {
 		f, err := newFunctionForTest(s.ctx, ast.Right, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		v, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			v, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(v.Kind(), Equals, types.KindNull)
@@ -652,11 +652,11 @@ func (s *testEvaluatorSuite) TestStrcmp(c *C) {
 	}
 	for _, t := range cases {
 		f, err := newFunctionForTest(s.ctx, ast.Strcmp, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		d, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			d, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull)
@@ -688,12 +688,12 @@ func (s *testEvaluatorSuite) TestReplace(c *C) {
 	}
 	for i, t := range cases {
 		f, err := newFunctionForTest(s.ctx, ast.Replace, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil, Commentf("test %v", i))
-		c.Assert(f.GetType().Flen, Equals, t.flen, Commentf("test %v", i))
-		d, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil, Commentf("test %v", i))
 		} else {
+			c.Assert(err, IsNil, Commentf("test %v", i))
+			c.Assert(f.GetType().Flen, Equals, t.flen, Commentf("test %v", i))
+			d, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil, Commentf("test %v", i))
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull, Commentf("test %v", i))
@@ -733,11 +733,11 @@ func (s *testEvaluatorSuite) TestSubstring(c *C) {
 	}
 	for _, t := range cases {
 		f, err := newFunctionForTest(s.ctx, ast.Substring, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		d, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			d, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull)
@@ -841,11 +841,11 @@ func (s *testEvaluatorSuite) TestSubstringIndex(c *C) {
 	}
 	for _, t := range cases {
 		f, err := newFunctionForTest(s.ctx, ast.SubstringIndex, s.primitiveValsToConstants(t.args)...)
-		c.Assert(err, IsNil)
-		d, err := f.Eval(chunk.Row{})
 		if t.getErr {
 			c.Assert(err, NotNil)
 		} else {
+			c.Assert(err, IsNil)
+			d, err := f.Eval(chunk.Row{})
 			c.Assert(err, IsNil)
 			if t.isNil {
 				c.Assert(d.Kind(), Equals, types.KindNull)
diff --git a/expression/integration_test.go b/expression/integration_test.go
index 72dda661201d2..f5df4cff8fc4a 100644
--- a/expression/integration_test.go
+++ b/expression/integration_test.go
@@ -5997,6 +5997,23 @@ func (s *testIntegrationSerialSuite) TestIssue15315(c *C) {
 	tk.MustQuery("select cast('0-1234' as real)").Check(testkit.Rows("0"))
 }
 
+func (s *testIntegrationSerialSuite) TestIssue23506(c *C) {
+	collate.SetNewCollationEnabledForTest(true)
+	defer collate.SetNewCollationEnabledForTest(false)
+	tk := testkit.NewTestKit(c, s.store)
+	tk.MustExec("use test")
+	// FIXME: MySQL reports the error "[HY000][3854] Cannot convert string '\x80' from binary to utf8mb4" here.
+	// For now TiDB can only report the illegal-mix-of-collations error below; we will try to fix this in the future.
+	tk.MustGetErrMsg("select 'a' collate utf8mb4_general_ci = 0x80;", "[expression:1267]Illegal mix of collations (utf8mb4_general_ci,EXPLICIT) and (binary,COERCIBLE) for operation '='")
+	tk.MustExec("drop table if exists t1")
+	tk.MustExec("create table t1(a char(10), primary key (a)) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci;")
+	tk.MustExec("insert into t1 values ('a')")
+	tk.MustGetErrMsg("select * from t1 where a > 0x80;", "[expression:1267]Illegal mix of collations (utf8mb4_general_ci,IMPLICIT) and (binary,COERCIBLE) for operation '>'")
+	tk.MustGetErrMsg("select * from t1 where a between 0x808E and 0x80BD;", "[expression:1267]Illegal mix of collations (utf8mb4_general_ci,IMPLICIT) and (binary,COERCIBLE) for operation '>='")
+	tk.MustGetErrMsg("select _ascii 'a' collate ascii_bin = 'ㅂ';", "[expression:1267]Illegal mix of collations (ascii_bin,EXPLICIT) and (utf8mb4_bin,COERCIBLE) for operation '='")
+	tk.MustGetErrMsg("select _ascii 'a' collate ascii_bin = 0x80;", "[expression:1267]Illegal mix of collations (ascii_bin,EXPLICIT) and (binary,COERCIBLE) for operation '='")
+}
+
 func (s *testIntegrationSuite) TestNotExistFunc(c *C) {
 	tk := testkit.NewTestKit(c, s.store)
 
diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go
index e0324123bbaae..a00ffe95a5b4f 100644
--- a/expression/typeinfer_test.go
+++ b/expression/typeinfer_test.go
@@ -240,7 +240,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase {
 		{"space(c_int_d)", mysql.TypeLongBlob, mysql.DefaultCharset, 0, mysql.MaxBlobWidth, types.UnspecifiedLength},
 		{"CONCAT(c_binary, c_int_d)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 40, types.UnspecifiedLength},
 		{"CONCAT(c_bchar, c_int_d)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 40, types.UnspecifiedLength},
-		{"CONCAT(c_bchar, 0x80)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength},
+		{"CONCAT(c_bchar, 0x63)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 23, types.UnspecifiedLength},
 		{"CONCAT('T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 4, types.UnspecifiedLength},
 		{"CONCAT('T', 'i', 'DB', c_binary)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength},
 		{"CONCAT_WS('-', 'T', 'i', 'DB')", mysql.TypeVarString, charset.CharsetUTF8MB4, 0 | mysql.NotNullFlag, 6, types.UnspecifiedLength},