Skip to content

Commit

Permalink
expression: fix wrong caseWhen function for enum type (#29454) (#29510)
Browse files Browse the repository at this point in the history
close #29357
  • Loading branch information
ti-srebot authored Jan 27, 2022
1 parent f9157d8 commit 2780dfb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
8 changes: 8 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}
bf.tp = fieldTp
if fieldTp.Tp == mysql.TypeEnum || fieldTp.Tp == mysql.TypeSet {
switch tp {
case types.ETInt:
fieldTp.Tp = mysql.TypeLonglong
case types.ETString:
fieldTp.Tp = mysql.TypeVarchar
}
}

switch tp {
case types.ETInt:
Expand Down
2 changes: 1 addition & 1 deletion expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func caseWhenHandler(expr *ScalarFunction) (Expression, bool) {
foldedExpr.GetType().Decimal = expr.GetType().Decimal
return foldedExpr, isDeferredConst
}
return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst
return foldedExpr, isDeferredConst
}
return expr, isDeferredConst
}
Expand Down
13 changes: 13 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9737,6 +9737,19 @@ func (s *testIntegrationSuite) TestControlFunctionWithEnumOrSet(c *C) {
tk.MustExec("insert into t values(1,1,1),(2,1,1),(1,1,1),(2,1,1);")
tk.MustQuery("select if(A, null,b)=1 from t;").Check(testkit.Rows("<nil>", "<nil>", "<nil>", "<nil>"))
tk.MustQuery("select if(A, null,b)='a' from t;").Check(testkit.Rows("<nil>", "<nil>", "<nil>", "<nil>"))

// issue 29357
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(`a` enum('y','b','Abc','null','1','2','0')) CHARSET=binary;")
tk.MustExec("insert into t values(\"1\");")
tk.MustQuery("SELECT count(*) from t where (null like 'a') = (case when cast('2015' as real) <=> round(\"1200\",\"1\") then a end);\n").Check(testkit.Rows("0"))
tk.MustQuery("SELECT (null like 'a') = (case when cast('2015' as real) <=> round(\"1200\",\"1\") then a end) from t;\n").Check(testkit.Rows("<nil>"))
tk.MustQuery("SELECT 5 = (case when 0 <=> 0 then a end) from t;").Check(testkit.Rows("1"))
tk.MustQuery("SELECT '1' = (case when 0 <=> 0 then a end) from t;").Check(testkit.Rows("1"))
tk.MustQuery("SELECT 5 = (case when 0 <=> 1 then a end) from t;").Check(testkit.Rows("<nil>"))
tk.MustQuery("SELECT '1' = (case when 0 <=> 1 then a end) from t;").Check(testkit.Rows("<nil>"))
tk.MustQuery("SELECT 5 = (case when 0 <=> 1 then a else a end) from t;").Check(testkit.Rows("1"))
tk.MustQuery("SELECT '1' = (case when 0 <=> 1 then a else a end) from t;").Check(testkit.Rows("1"))
}

func (s *testIntegrationSuite) TestComplexShowVariables(c *C) {
Expand Down

0 comments on commit 2780dfb

Please sign in to comment.