diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index bfbb005814eb0..70626720e6e83 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -469,6 +469,14 @@ func (c *greatestFunctionClass) getFunction(ctx sessionctx.Context, args []Expre } switch tp { case types.ETInt: + // adjust unsigned flag + greastInitUnsignedFlag := false + if isEqualsInitUnsignedFlag(greastInitUnsignedFlag, args) { + bf.tp.Flag &= ^mysql.UnsignedFlag + } else { + bf.tp.Flag |= mysql.UnsignedFlag + } + sig = &builtinGreatestIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_GreatestInt) case types.ETReal: @@ -736,6 +744,14 @@ func (c *leastFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi } switch tp { case types.ETInt: + // adjust unsigned flag + leastInitUnsignedFlag := true + if isEqualsInitUnsignedFlag(leastInitUnsignedFlag, args) { + bf.tp.Flag |= mysql.UnsignedFlag + } else { + bf.tp.Flag &= ^mysql.UnsignedFlag + } + sig = &builtinLeastIntSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LeastInt) case types.ETReal: @@ -2846,3 +2862,15 @@ func CompareJSON(sctx sessionctx.Context, lhsArg, rhsArg Expression, lhsRow, rhs } return int64(json.CompareBinary(arg0, arg1)), false, nil } + +// isEqualsInitUnsignedFlag can adjust unsigned flag for greatest/least function. +// For greatest, returns unsigned result if there is at least one argument is unsigned. +// For least, returns signed result if there is at least one argument is signed. +func isEqualsInitUnsignedFlag(initUnsigned bool, args []Expression) bool { + for _, arg := range args { + if initUnsigned != mysql.HasUnsignedFlag(arg.GetType().Flag) { + return false + } + } + return true +} diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index 9515d80ded3bb..16cb7f906448b 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -263,6 +263,8 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { sc := s.ctx.GetSessionVars().StmtCtx originIgnoreTruncate := sc.IgnoreTruncate sc.IgnoreTruncate = true + decG := &types.MyDecimal{} + decL := &types.MyDecimal{} defer func() { sc.IgnoreTruncate = originIgnoreTruncate }() @@ -274,6 +276,14 @@ func (s *testEvaluatorSuite) TestGreatestLeastFunc(c *C) { isNil bool getErr bool }{ + { + []interface{}{int64(-9223372036854775808), uint64(9223372036854775809)}, + decG.FromUint(9223372036854775809), decL.FromInt(-9223372036854775808), false, false, + }, + { + []interface{}{uint64(9223372036854775808), uint64(9223372036854775809)}, + uint64(9223372036854775809), uint64(9223372036854775808), false, false, + }, { []interface{}{1, 2, 3, 4}, int64(4), int64(1), false, false, diff --git a/expression/integration_test.go b/expression/integration_test.go index 8fef748f8f2ed..a333053d416ac 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -9367,6 +9367,15 @@ func (s *testIntegrationSuite) TestConstPropNullFunctions(c *C) { tk.MustQuery("select * from t2 where t2.i2=((select count(1) from t1 where t1.i1=t2.i2))").Check(testkit.Rows("1 0.1")) } +func (s *testIntegrationSuite) TestIssue30101(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(c1 bigint unsigned, c2 bigint unsigned);") + tk.MustExec("insert into t1 values(9223372036854775808, 9223372036854775809);") + tk.MustQuery("select greatest(c1, c2) from t1;").Sort().Check(testkit.Rows("9223372036854775809")) +} + func (s *testIntegrationSuite) TestIssue28643(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 1039eb1bef4bb..4f37419eefc03 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -1035,6 +1035,13 @@ func (s *testInferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase { {"interval(c_int_d, c_int_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"interval(c_int_d, c_float_d, c_double_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + + {"greatest(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"greatest(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 11, 0}, + {"least(c_bigint_d, c_ubigint_d, c_int_d)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"least(c_ubigint_d, c_ubigint_d, c_uint_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, mysql.MaxIntWidth, 0}, + {"least(c_uint_d, c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, } }