Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner, type: fix AggFieldType error when encouter unsigned and sign type #21062

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,21 @@ import (

// NewOne stands for a number 1.
func NewOne() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand why we need to change these 2 funcs

Copy link
Contributor Author

@rogeryk rogeryk Nov 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because i make test fail when i modify AggFieldType. The test is https://github.com/pingcap/tidb/blob/master/planner/cascades/transformation_rules_test.go#L389. "SQL": "select count(b), sum(b), avg(b), b, max(b), min(b), bit_and(b), bit_or(b), bit_xor(b) from t group by a having sum(b) >= 0 and count(b) >= 0 order by b",

The reason for fail is that the bit_or will add a cast wrap (cast(b as bigint unsigned binary), the arg is unsigned.
the TransformAggToProj will add ifnull wrap ifnull(cast(b as bigint unsigned binary), expression.NewZero()).
ifnull use AggFieldType infer type to tigger integral promotion.
https://github.com/pingcap/tidb/blob/master/planner/core/rule_aggregation_elimination.go#L124

I think this is a easy way to process. what do you think?

retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion
return &Constant{
Value: types.NewDatum(1),
RetType: types.NewFieldType(mysql.TypeTiny),
RetType: retT,
}
}

// NewZero stands for a number 0.
func NewZero() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion
return &Constant{
Value: types.NewDatum(0),
RetType: types.NewFieldType(mysql.TypeTiny),
RetType: retT,
}
}

Expand Down
36 changes: 36 additions & 0 deletions planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,39 @@ func (s *testExpressionRewriterSuite) TestIssue20007(c *C) {
testkit.Rows("2 epic wiles 2020-01-02 23:29:51", "3 silly burnell 2020-02-25 07:43:07"))
}
}

func (s *testExpressionRewriterSuite) TestIssue9869(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()

tk.MustExec("use test;")
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(a int, b bigint unsigned);")
tk.MustExec("insert into t1 (a, b) values (1,4572794622775114594), (2,18196094287899841997),(3,11120436154190595086);")
tk.MustQuery("select (case t1.a when 0 then 0 else t1.b end), cast(t1.b as signed) from t1;").Check(
testkit.Rows("4572794622775114594 4572794622775114594", "18196094287899841997 -250649785809709619", "11120436154190595086 -7326307919518956530"))
}

func (s *testExpressionRewriterSuite) TestIssue17652(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()

tk.MustExec("use test;")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(x bigint unsigned);")
tk.MustExec("insert into t values( 9999999703771440633);")
tk.MustQuery("select ifnull(max(x), 0) from t").Check(
testkit.Rows("9999999703771440633"))
}
9 changes: 9 additions & 0 deletions types/etc.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func IsTypeTime(tp byte) bool {
return tp == mysql.TypeDatetime || tp == mysql.TypeDate || tp == mysql.TypeTimestamp
}

// IsTypeInteger returns a boolean indicating whether the tp is integer type.
func IsTypeInteger(tp byte) bool {
switch tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
return true
}
return false
}

// IsTypeNumeric returns a boolean indicating whether the tp is numeric type.
func IsTypeNumeric(tp byte) bool {
switch tp {
Expand Down
29 changes: 26 additions & 3 deletions types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,38 @@ func NewFieldTypeWithCollation(tp byte, collation string, length int) *FieldType
// Aggregation is performed by MergeFieldType function.
func AggFieldType(tps []*FieldType) *FieldType {
var currType FieldType
isMixedSign := false
for i, t := range tps {
if i == 0 && currType.Tp == mysql.TypeUnspecified {
currType = *t
continue
}
mtp := MergeFieldType(currType.Tp, t.Tp)
isMixedSign = isMixedSign || (mysql.HasUnsignedFlag(currType.Flag) != mysql.HasUnsignedFlag(t.Flag))
currType.Tp = mtp
currType.Flag = mergeTypeFlag(currType.Flag, t.Flag)
}
// integral promotion when tps contains signed and unsigned
if isMixedSign && IsTypeInteger(currType.Tp) {
bumpRange := false // indicate one of tps bump currType range
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does MySQL also bumpRange ?

Copy link
Contributor Author

@rogeryk rogeryk Nov 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for _, t := range tps {
bumpRange = bumpRange || (mysql.HasUnsignedFlag(t.Flag) && (t.Tp == currType.Tp || t.Tp == mysql.TypeBit))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why check mysql.TypeBit specially

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the TypeBit can't be mergeType and assume it max range bit(64). so it must bump.
I find the all IntegerType merege mysql.TypeBit is mysql.TypeVarchar.
It is inconsistent with mysql 8.0 https://github.com/mysql/mysql-server/blob/8.0/sql/field.cc#L248.

Whether it will be consistent with mysql in the future. if not, No need to consider mysql.TypeBit here

}
if bumpRange {
switch currType.Tp {
case mysql.TypeTiny:
currType.Tp = mysql.TypeShort
case mysql.TypeShort:
currType.Tp = mysql.TypeInt24
case mysql.TypeInt24:
currType.Tp = mysql.TypeLong
case mysql.TypeLong:
currType.Tp = mysql.TypeLonglong
case mysql.TypeLonglong:
currType.Tp = mysql.TypeNewDecimal
}
}
}

return &currType
}
Expand Down Expand Up @@ -311,10 +334,10 @@ func MergeFieldType(a byte, b byte) byte {
}

// mergeTypeFlag merges two MySQL type flag to a new one
// currently only NotNullFlag is checked
// todo more flag need to be checked, for example: UnsignedFlag
// currently only NotNullFlag and UnsignedFlag is checked
// todo more flag need to be checked
func mergeTypeFlag(a, b uint) uint {
return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag)
return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) & (b&mysql.UnsignedFlag | ^mysql.UnsignedFlag)
}

func getFieldTypeIndex(tp byte) int {
Expand Down
38 changes: 38 additions & 0 deletions types/field_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,44 @@ func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) {
c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag)
}

func (s testFieldTypeSuite) TestAggFieldTypeForIntegralPromotion(c *C) {
fts := []*FieldType{
NewFieldType(mysql.TypeTiny),
NewFieldType(mysql.TypeShort),
NewFieldType(mysql.TypeInt24),
NewFieldType(mysql.TypeLong),
NewFieldType(mysql.TypeLonglong),
NewFieldType(mysql.TypeNewDecimal),
}

for i := 1; i < len(fts)-1; i++ {
tps := fts[i-1 : i+1]

tps[0].Flag = 0
tps[1].Flag = 0
aggTp := AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))

tps[0].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))

tps[0].Flag = mysql.UnsignedFlag
tps[1].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, mysql.UnsignedFlag)

tps[0].Flag = 0
tps[1].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i+1].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))
}
}

func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) {
defer testleak.AfterTest(c)()
fts := []*FieldType{
Expand Down