Skip to content

Commit

Permalink
planner: fix wrong collation when rewrite in condition (#30492)
Browse files Browse the repository at this point in the history
close #30486
  • Loading branch information
wjhuang2016 authored Dec 21, 2021
1 parent e12342b commit 416617e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
10 changes: 10 additions & 0 deletions expression/integration_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ func TestCollationBasic(t *testing.T) {
tk.MustQuery("select * from t1 where col1 >= 0xc484 and col1 <= 0xc3b3;").Check(testkit.Rows("Ȇ"))

tk.MustQuery("select collation(IF('a' < 'B' collate utf8mb4_general_ci, 'smaller', 'greater' collate utf8mb4_unicode_ci));").Check(testkit.Rows("utf8mb4_unicode_ci"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(10))")
tk.MustExec("insert into t values ('a')")
tk.MustQuery("select * from t where a in ('b' collate utf8mb4_general_ci, 'A', 3)").Check(testkit.Rows("a"))
// The results of these test cases may differ from MySQL, but this behavior is more reasonable.
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci));").Check(testkit.Rows("1"))
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin));").Check(testkit.Rows("0"))
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_general_ci), ('b', 'b'));").Check(testkit.Rows("1"))
tk.MustQuery("select ('a', 'a') in (('A' collate utf8mb4_general_ci, 'A' collate utf8mb4_bin), ('b', 'b'));").Check(testkit.Rows("0"))
}

func TestWeightString(t *testing.T) {
Expand Down
60 changes: 60 additions & 0 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,12 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
if allSameType && l == 1 && lLen > 1 {
function = er.notToExpression(not, ast.In, tp, er.ctxStack[stkLen-lLen-1:]...)
} else {
// If we rewrite IN to EQ, we need to decide what's the collation EQ uses.
coll := er.deriveCollationForIn(l, lLen, stkLen, args)
if er.err != nil {
return
}
er.castCollationForIn(l, lLen, stkLen, coll)
eqFunctions := make([]expression.Expression, 0, lLen)
for i := stkLen - lLen; i < stkLen; i++ {
expr, err := er.constructBinaryOpFunction(args[0], er.ctxStack[i], ast.EQ)
Expand All @@ -1515,6 +1521,60 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field
er.ctxStackAppend(function, types.EmptyName)
}

// deriveCollationForIn derives the collation used to compare each column of an
// IN expression and returns one *expression.ExprCollation per column.
//
// colLen is the number of columns on each side of IN (1 for `a in (x, y)`,
// n for `(a1..an) in ((x1..xn), ...)`). elemCnt is the number of elements in
// the IN list. stkLen is the stack length used to locate the operands: the
// entries er.ctxStack[stkLen-elemCnt-1 : stkLen] are read, i.e. the IN-list
// elements plus the entry just below them (presumably the left operand).
// args is only consulted when colLen == 1.
//
// er.err is overwritten with the result of every derivation; when a derivation
// fails, the error is left in er.err and nil is returned.
func (er *expressionRewriter) deriveCollationForIn(colLen int, elemCnt int, stkLen int, args []expression.Expression) []*expression.ExprCollation {
	coll := make([]*expression.ExprCollation, 0, colLen)
	if colLen == 1 {
		// a in (x, y, z) => coll[0]
		var derived *expression.ExprCollation
		derived, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, args...)
		if er.err != nil {
			return nil
		}
		return append(coll, derived)
	}

	// (a, b, c) in ((x1, x2, x3), (y1, y2, y3), (z1, z2, z3)) => coll[0], coll[1], coll[2]
	first := stkLen - elemCnt - 1
	for i := 0; i < colLen; i++ {
		// Gather the i-th member of every row. The range covers elemCnt+1 stack
		// entries, so size the buffer accordingly to avoid a regrow per column.
		colArgs := make([]expression.Expression, 0, elemCnt+1)
		for j := first; j < stkLen; j++ {
			// Every entry is expected to be a row function here; a mismatch is a
			// programmer error and the assertion reports the offending type
			// instead of failing later with a nil dereference.
			rowFunc := er.ctxStack[j].(*expression.ScalarFunction)
			colArgs = append(colArgs, rowFunc.GetArgs()[i])
		}
		var derived *expression.ExprCollation
		derived, er.err = expression.CheckAndDeriveCollationFromExprs(er.sctx, "IN", types.ETInt, colArgs...)
		if er.err != nil {
			return nil
		}
		coll = append(coll, derived)
	}
	return coll
}

// castCollationForIn casts collation info for arguments in the `in clause` to make sure the used collation is correct after we
// rewrite it to equal expression.
//
// The IN-list elements are er.ctxStack[stkLen-elemCnt : stkLen]. colLen is the
// number of columns per element and coll holds one derived collation per
// column (coll[0] for the single-column form). Every string-typed operand is
// wrapped in a cast to the derived charset/collation and marked with explicit
// coercibility; non-string operands are left untouched.
//
// For the single-column form the stack entry itself is replaced. For the row
// form the row function's argument slice is updated in place.
func (er *expressionRewriter) castCollationForIn(colLen int, elemCnt int, stkLen int, coll []*expression.ExprCollation) {
	for i := stkLen - elemCnt; i < stkLen; i++ {
		elem := er.ctxStack[i]

		if colLen == 1 {
			// a in (x, y, z): only string elements carry a collation.
			if elem.GetType().EvalType() != types.ETString {
				continue
			}
			tp := elem.GetType().Clone()
			tp.Charset, tp.Collate = coll[0].Charset, coll[0].Collation
			casted := expression.BuildCastFunction(er.sctx, elem, tp)
			casted.SetCoercibility(expression.CoercibilityExplicit)
			er.ctxStack[i] = casted
			continue
		}

		// (a, b) in ((x1, x2), ...): each element must be a row function.
		rowFunc, ok := elem.(*expression.ScalarFunction)
		if !ok {
			// Not a row: there are no per-column members to re-collate.
			continue
		}
		rowArgs := rowFunc.GetArgs()
		for j := 0; j < colLen; j++ {
			// Decide per member, not per row: the row as a whole says nothing
			// about whether its j-th member is a string.
			if rowArgs[j].GetType().EvalType() != types.ETString {
				continue
			}
			tp := rowArgs[j].GetType().Clone()
			tp.Charset, tp.Collate = coll[j].Charset, coll[j].Collation
			casted := expression.BuildCastFunction(er.sctx, rowArgs[j], tp)
			casted.SetCoercibility(expression.CoercibilityExplicit)
			rowArgs[j] = casted
		}
	}
}

func (er *expressionRewriter) caseToExpression(v *ast.CaseExpr) {
stkLen := len(er.ctxStack)
argsLen := 2 * len(v.WhenClauses)
Expand Down

0 comments on commit 416617e

Please sign in to comment.