Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: check for null values when comparing different groups during streamAgg #15742

Merged
merged 9 commits into from
Mar 27, 2020
96 changes: 73 additions & 23 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,15 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
return err
}

previousIsNull := col.IsNull(0)
var firstRowDatum, lastRowDatum types.Datum
firstRowIsNull, lastRowIsNull := col.IsNull(0), col.IsNull(numRows-1)
if firstRowIsNull {
firstRowDatum.SetNull()
}
if lastRowIsNull {
lastRowDatum.SetNull()
}
Comment on lines +1116 to +1121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this behind the `switch eType`, and reduce the `if` null checks inside the switch.

previousIsNull := firstRowIsNull
switch eType {
case types.ETInt:
vals := col.Int64s()
Expand All @@ -1128,8 +1135,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetInt64(vals[0])
lastRowDatum.SetInt64(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetInt64(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetInt64(vals[numRows-1])
}
case types.ETReal:
vals := col.Float64s()
for i := 1; i < numRows; i++ {
Expand All @@ -1144,8 +1155,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetFloat64(vals[0])
lastRowDatum.SetFloat64(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetFloat64(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetFloat64(vals[numRows-1])
}
case types.ETDecimal:
vals := col.Decimals()
for i := 1; i < numRows; i++ {
Expand All @@ -1160,10 +1175,16 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
// make a copy to avoid DATA RACE
firstDatum, lastDatum := vals[0], vals[numRows-1]
firstRowDatum.SetMysqlDecimal(&firstDatum)
lastRowDatum.SetMysqlDecimal(&lastDatum)
if !firstRowIsNull {
// make a copy to avoid DATA RACE
firstDatum := vals[0]
firstRowDatum.SetMysqlDecimal(&firstDatum)
}
if !lastRowIsNull {
// make a copy to avoid DATA RACE
lastDatum := vals[numRows-1]
lastRowDatum.SetMysqlDecimal(&lastDatum)
}
case types.ETDatetime, types.ETTimestamp:
vals := col.Times()
for i := 1; i < numRows; i++ {
Expand All @@ -1178,8 +1199,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetMysqlTime(vals[0])
lastRowDatum.SetMysqlTime(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetMysqlTime(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetMysqlTime(vals[numRows-1])
}
case types.ETDuration:
vals := col.GoDurations()
for i := 1; i < numRows; i++ {
Expand All @@ -1194,24 +1219,44 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)})
lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)})
if !firstRowIsNull {
firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)})
}
if !lastRowIsNull {
lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)})
}
case types.ETJson:
previousKey := col.GetJSON(0)
var previousKey, key json.BinaryJSON
if !previousIsNull {
previousKey = col.GetJSON(0)
}
for i := 1; i < numRows; i++ {
key := col.GetJSON(i)
isNull := col.IsNull(i)
if !isNull {
key = col.GetJSON(i)
}
if e.sameGroup[i] {
if isNull != previousIsNull || json.CompareBinary(previousKey, key) != 0 {
if isNull == previousIsNull {
if !isNull && json.CompareBinary(previousKey, key) != 0 {
e.sameGroup[i] = false
}
} else {
e.sameGroup[i] = false
}
}
previousKey = key
if !isNull {
previousKey = key
}
previousIsNull = isNull
}
// make a copy to avoid DATA RACE
firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy())
lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy())
if !firstRowIsNull {
// make a copy to avoid DATA RACE
firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy())
}
if !lastRowIsNull {
// make a copy to avoid DATA RACE
lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy())
}
case types.ETString:
previousKey := codec.ConvertByCollationStr(col.GetString(0), tp)
for i := 1; i < numRows; i++ {
Expand All @@ -1225,9 +1270,14 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
previousKey = key
previousIsNull = isNull
}
// don't use col.GetString since it will cause DATA RACE
firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate)
lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate)
if !firstRowIsNull {
// don't use col.GetString since it will cause DATA RACE
firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate)
}
if !lastRowIsNull {
// don't use col.GetString since it will cause DATA RACE
lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate)
}
default:
err = errors.New(fmt.Sprintf("invalid eval type %v", eType))
}
Expand Down
51 changes: 51 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,54 @@ func (s *testSuiteAgg) TestPR15242ShallowCopy(c *C) {
tk.MustQuery(`select max(JSON_EXTRACT(a, '$.score')) as max_score,JSON_EXTRACT(a,'$.id') as id from t group by id order by id;`).Check(testkit.Rows("233 1", "233 2", "233 3"))

}

// TestIssue15690 runs `select /*+ stream_agg() */ distinct *` against a
// single-column table that holds NULL rows (followed, in most cases, by
// non-NULL rows) and verifies, for INT, FLOAT, DECIMAL, DATETIME, JSON and
// CHAR columns in turn, that:
//   - the NULL rows collapse into a single "<nil>" result row, and
//   - the statement finishes with a warning count of zero.
func (s *testSuiteAgg) TestIssue15690(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
// NOTE(review): presumably the tiny chunk size makes the inserted rows span
// several chunks so that group boundaries fall on chunk edges — confirm.
tk.Se.GetSessionVars().MaxChunkSize = 2
// check for INT type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a int);`)
tk.MustExec(`insert into t values(null),(null);`)
tk.MustExec(`insert into t values(0),(2),(2),(4),(8);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "0", "2", "4", "8"))
SunRunAway marked this conversation as resolved.
Show resolved Hide resolved
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for FLOAT type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a float);`)
tk.MustExec(`insert into t values(null),(null),(null),(null);`)
tk.MustExec(`insert into t values(1.1),(1.1);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "1.1"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for DECIMAL type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a decimal(5,1));`)
tk.MustExec(`insert into t values(null),(null),(null);`)
tk.MustExec(`insert into t values(1.1),(2.2),(2.2);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "1.1", "2.2"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for DATETIME type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a datetime);`)
tk.MustExec(`insert into t values(null);`)
tk.MustExec(`insert into t values("2019-03-20 21:50:00"),("2019-03-20 21:50:01"), ("2019-03-20 21:50:00");`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "2019-03-20 21:50:00", "2019-03-20 21:50:01"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for JSON type
// Only NULL rows are inserted here, so a single "<nil>" row is expected.
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a json);`)
tk.MustExec(`insert into t values(null),(null),(null),(null);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for char type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a char);`)
tk.MustExec(`insert into t values(null),(null),(null),(null);`)
tk.MustExec(`insert into t values('a'),('b');`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "a", "b"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))
}
33 changes: 30 additions & 3 deletions executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,25 +713,40 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {

func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) {
chkNum := len(chkRows)
numRows := 0
inputs = make([]*chunk.Chunk, chkNum)
fts := make([]*types.FieldType, 1)
fts[0] = types.NewFieldType(mysql.TypeLonglong)
for i := 0; i < chkNum; i++ {
inputs[i] = chunk.New(fts, chkRows[i], chkRows[i])
numRows += chkRows[i]
}
var numGroups int
if numRows%sameNum == 0 {
numGroups = numRows / sameNum
} else {
numGroups = numRows/sameNum + 1
}

rand.Seed(time.Now().Unix())
nullPos := rand.Intn(numGroups)
cnt := 0
val := 0
val := rand.Int63()
for i := 0; i < chkNum; i++ {
col := inputs[i].Column(0)
col.ResizeInt64(chkRows[i], false)
i64s := col.Int64s()
for j := 0; j < chkRows[i]; j++ {
if cnt == sameNum {
val++
val = rand.Int63()
cnt = 0
nullPos--
}
if nullPos == 0 {
col.SetNull(j, true)
} else {
i64s[j] = val
}
i64s[j] = int64(val)
cnt++
}
}
Expand Down Expand Up @@ -775,6 +790,18 @@ func (s *testExecSuite) TestVecGroupChecker(c *C) {
expectedFlag: []bool{false, false},
sameNum: 1,
},
{
chunkRows: []int{2, 2},
expectedGroups: 2,
expectedFlag: []bool{false, false},
sameNum: 2,
},
{
chunkRows: []int{2, 2},
expectedGroups: 1,
expectedFlag: []bool{false, true},
sameNum: 4,
},
}

ctx := mock.NewContext()
Expand Down