diff --git a/cmd/explaintest/r/collation_check_use_collation.result b/cmd/explaintest/r/collation_check_use_collation.result new file mode 100644 index 0000000000000..ffd787a4cef43 --- /dev/null +++ b/cmd/explaintest/r/collation_check_use_collation.result @@ -0,0 +1,25 @@ +create database collation_check_use_collation; +use collation_check_use_collation; +CREATE TABLE `t` ( +`a` char(10) DEFAULT NULL +); +CREATE TABLE `t1` ( +`a` char(10) COLLATE utf8mb4_general_ci DEFAULT NULL +); +insert into t values ("a"); +insert into t1 values ("A"); +select a as a_col from t where t.a = all (select a collate utf8mb4_general_ci from t1); +a_col +a +select a as a_col from t where t.a != any (select a collate utf8mb4_general_ci from t1); +a_col +select a as a_col from t where t.a <= all (select a collate utf8mb4_general_ci from t1); +a_col +a +select a as a_col from t where t.a <= any (select a collate utf8mb4_general_ci from t1); +a_col +a +select a as a_col from t where t.a = (select a collate utf8mb4_general_ci from t1); +a_col +a +use test diff --git a/cmd/explaintest/t/collation_check_use_collation.test b/cmd/explaintest/t/collation_check_use_collation.test new file mode 100644 index 0000000000000..67e75f32e38f9 --- /dev/null +++ b/cmd/explaintest/t/collation_check_use_collation.test @@ -0,0 +1,23 @@ +# These tests check that the used collation is correct. + +# prepare database +create database collation_check_use_collation; +use collation_check_use_collation; + +# Check subquery. +CREATE TABLE `t` ( + `a` char(10) DEFAULT NULL +); +CREATE TABLE `t1` ( + `a` char(10) COLLATE utf8mb4_general_ci DEFAULT NULL +); +insert into t values ("a"); +insert into t1 values ("A"); +select a as a_col from t where t.a = all (select a collate utf8mb4_general_ci from t1); +select a as a_col from t where t.a != any (select a collate utf8mb4_general_ci from t1); +select a as a_col from t where t.a <= all (select a collate utf8mb4_general_ci from t1); +select a as a_col from t where t.a <= any (select a collate utf8mb4_general_ci from t1); +select a as a_col from t where t.a = (select a collate utf8mb4_general_ci from t1); + +# cleanup environment +use test diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 8f950b3d3ece5..775aeda6a880f 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -618,6 +618,7 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: funcMaxOrMin.RetTp, } + colMaxOrMin.SetCoercibility(rexpr.Coercibility()) schema := expression.NewSchema(colMaxOrMin) plan4Agg.names = append(plan4Agg.names, types.EmptyName) @@ -735,6 +736,7 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: maxFunc.RetTp, } + maxResultCol.SetCoercibility(rexpr.Coercibility()) count := &expression.Column{ UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: countFunc.RetTp, @@ -772,6 +774,7 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: firstRowFunc.RetTp, } + firstRowResultCol.SetCoercibility(rexpr.Coercibility()) plan4Agg.names = append(plan4Agg.names, types.EmptyName) count := &expression.Column{ UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), @@ -1008,9 +1011,11 @@ func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, v *ast.S if np.Schema().Len() > 1 { newCols := make([]expression.Expression, 0, np.Schema().Len()) for i, data := range row { - newCols = append(newCols, &expression.Constant{ + constant := &expression.Constant{ Value: data, - RetType: np.Schema().Columns[i].GetType()}) + RetType: np.Schema().Columns[i].GetType()} + constant.SetCoercibility(np.Schema().Columns[i].Coercibility()) + newCols = append(newCols, constant) } expr, err1 := er.newFunction(ast.RowFunc, newCols[0].GetType(), newCols...) if err1 != nil { @@ -1019,10 +1024,12 @@ func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, v *ast.S } er.ctxStackAppend(expr, types.EmptyName) } else { - er.ctxStackAppend(&expression.Constant{ + constant := &expression.Constant{ Value: row[0], RetType: np.Schema().Columns[0].GetType(), - }, types.EmptyName) + } + constant.SetCoercibility(np.Schema().Columns[0].Coercibility()) + er.ctxStackAppend(constant, types.EmptyName) } return v, true }