From f365238981a577e25e86029060947e67d729e08a Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Mon, 8 Nov 2021 10:37:03 +0800 Subject: [PATCH] expression: fix wrong caseWhen function for enum type (#29454) (#29512) --- expression/builtin_control.go | 8 ++++++++ expression/constant_fold.go | 2 +- expression/integration_test.go | 13 +++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/expression/builtin_control.go b/expression/builtin_control.go index e9b39bf36ab5c..6ea2c119176c7 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -237,6 +237,14 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre return nil, err } bf.tp = fieldTp + if fieldTp.Tp == mysql.TypeEnum || fieldTp.Tp == mysql.TypeSet { + switch tp { + case types.ETInt: + fieldTp.Tp = mysql.TypeLonglong + case types.ETString: + fieldTp.Tp = mysql.TypeVarchar + } + } switch tp { case types.ETInt: diff --git a/expression/constant_fold.go b/expression/constant_fold.go index f08a5c45abf60..d7cbaf5d8edd6 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -143,7 +143,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { foldedExpr.GetType().Decimal = expr.GetType().Decimal return foldedExpr, isDeferredConst } - return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + return foldedExpr, isDeferredConst } return expr, isDeferredConst } diff --git a/expression/integration_test.go b/expression/integration_test.go index 6e3cae50fe249..22338090000c4 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -9936,6 +9936,19 @@ func (s *testIntegrationSuite) TestControlFunctionWithEnumOrSet(c *C) { tk.MustExec("insert into t values(1,1,1),(2,1,1),(1,1,1),(2,1,1);") tk.MustQuery("select if(A, null,b)=1 from t;").Check(testkit.Rows("", "", "", "")) tk.MustQuery("select if(A, null,b)='a' from t;").Check(testkit.Rows("", "", "", "")) + + // issue 29357 + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(`a` enum('y','b','Abc','null','1','2','0')) CHARSET=binary;") + tk.MustExec("insert into t values(\"1\");") + tk.MustQuery("SELECT count(*) from t where (null like 'a') = (case when cast('2015' as real) <=> round(\"1200\",\"1\") then a end);\n").Check(testkit.Rows("0")) + tk.MustQuery("SELECT (null like 'a') = (case when cast('2015' as real) <=> round(\"1200\",\"1\") then a end) from t;\n").Check(testkit.Rows("")) + tk.MustQuery("SELECT 5 = (case when 0 <=> 0 then a end) from t;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT '1' = (case when 0 <=> 0 then a end) from t;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT 5 = (case when 0 <=> 1 then a end) from t;").Check(testkit.Rows("")) + tk.MustQuery("SELECT '1' = (case when 0 <=> 1 then a end) from t;").Check(testkit.Rows("")) + tk.MustQuery("SELECT 5 = (case when 0 <=> 1 then a else a end) from t;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT '1' = (case when 0 <=> 1 then a else a end) from t;").Check(testkit.Rows("1")) } func (s *testIntegrationSuite) TestComplexShowVariables(c *C) {