Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: check for null values when comparing different groups during streamAgg (#15742) #15777

Merged
merged 3 commits into from
Mar 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 73 additions & 23 deletions executor/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,8 +1111,15 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
return err
}

previousIsNull := col.IsNull(0)
var firstRowDatum, lastRowDatum types.Datum
firstRowIsNull, lastRowIsNull := col.IsNull(0), col.IsNull(numRows-1)
if firstRowIsNull {
firstRowDatum.SetNull()
}
if lastRowIsNull {
lastRowDatum.SetNull()
}
previousIsNull := firstRowIsNull
switch eType {
case types.ETInt:
vals := col.Int64s()
Expand All @@ -1128,8 +1135,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetInt64(vals[0])
lastRowDatum.SetInt64(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetInt64(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetInt64(vals[numRows-1])
}
case types.ETReal:
vals := col.Float64s()
for i := 1; i < numRows; i++ {
Expand All @@ -1144,8 +1155,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetFloat64(vals[0])
lastRowDatum.SetFloat64(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetFloat64(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetFloat64(vals[numRows-1])
}
case types.ETDecimal:
vals := col.Decimals()
for i := 1; i < numRows; i++ {
Expand All @@ -1160,10 +1175,16 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
// make a copy to avoid DATA RACE
firstDatum, lastDatum := vals[0], vals[numRows-1]
firstRowDatum.SetMysqlDecimal(&firstDatum)
lastRowDatum.SetMysqlDecimal(&lastDatum)
if !firstRowIsNull {
// make a copy to avoid DATA RACE
firstDatum := vals[0]
firstRowDatum.SetMysqlDecimal(&firstDatum)
}
if !lastRowIsNull {
// make a copy to avoid DATA RACE
lastDatum := vals[numRows-1]
lastRowDatum.SetMysqlDecimal(&lastDatum)
}
case types.ETDatetime, types.ETTimestamp:
vals := col.Times()
for i := 1; i < numRows; i++ {
Expand All @@ -1178,8 +1199,12 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetMysqlTime(vals[0])
lastRowDatum.SetMysqlTime(vals[numRows-1])
if !firstRowIsNull {
firstRowDatum.SetMysqlTime(vals[0])
}
if !lastRowIsNull {
lastRowDatum.SetMysqlTime(vals[numRows-1])
}
case types.ETDuration:
vals := col.GoDurations()
for i := 1; i < numRows; i++ {
Expand All @@ -1194,24 +1219,44 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
}
previousIsNull = isNull
}
firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)})
lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)})
if !firstRowIsNull {
firstRowDatum.SetMysqlDuration(types.Duration{Duration: vals[0], Fsp: int8(item.GetType().Decimal)})
}
if !lastRowIsNull {
lastRowDatum.SetMysqlDuration(types.Duration{Duration: vals[numRows-1], Fsp: int8(item.GetType().Decimal)})
}
case types.ETJson:
previousKey := col.GetJSON(0)
var previousKey, key json.BinaryJSON
if !previousIsNull {
previousKey = col.GetJSON(0)
}
for i := 1; i < numRows; i++ {
key := col.GetJSON(i)
isNull := col.IsNull(i)
if !isNull {
key = col.GetJSON(i)
}
if e.sameGroup[i] {
if isNull != previousIsNull || json.CompareBinary(previousKey, key) != 0 {
if isNull == previousIsNull {
if !isNull && json.CompareBinary(previousKey, key) != 0 {
e.sameGroup[i] = false
}
} else {
e.sameGroup[i] = false
}
}
previousKey = key
if !isNull {
previousKey = key
}
previousIsNull = isNull
}
// make a copy to avoid DATA RACE
firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy())
lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy())
if !firstRowIsNull {
// make a copy to avoid DATA RACE
firstRowDatum.SetMysqlJSON(col.GetJSON(0).Copy())
}
if !lastRowIsNull {
// make a copy to avoid DATA RACE
lastRowDatum.SetMysqlJSON(col.GetJSON(numRows - 1).Copy())
}
case types.ETString:
previousKey := codec.ConvertByCollationStr(col.GetString(0), tp)
for i := 1; i < numRows; i++ {
Expand All @@ -1225,9 +1270,14 @@ func (e *vecGroupChecker) evalGroupItemsAndResolveGroups(item expression.Express
previousKey = key
previousIsNull = isNull
}
// don't use col.GetString since it will cause DATA RACE
firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate)
lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate)
if !firstRowIsNull {
// don't use col.GetString since it will cause DATA RACE
firstRowDatum.SetString(string(col.GetBytes(0)), tp.Collate)
}
if !lastRowIsNull {
// don't use col.GetString since it will cause DATA RACE
lastRowDatum.SetString(string(col.GetBytes(numRows-1)), tp.Collate)
}
default:
err = errors.New(fmt.Sprintf("invalid eval type %v", eType))
}
Expand Down
51 changes: 51 additions & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,54 @@ func (s *testSuiteAgg) TestPR15242ShallowCopy(c *C) {
tk.MustQuery(`select max(JSON_EXTRACT(a, '$.score')) as max_score,JSON_EXTRACT(a,'$.id') as id from t group by id order by id;`).Check(testkit.Rows("233 1", "233 2", "233 3"))

}

func (s *testSuiteAgg) TestIssue15690(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.Se.GetSessionVars().MaxChunkSize = 2
// check for INT type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a int);`)
tk.MustExec(`insert into t values(null),(null);`)
tk.MustExec(`insert into t values(0),(2),(2),(4),(8);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "0", "2", "4", "8"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for FLOAT type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a float);`)
tk.MustExec(`insert into t values(null),(null),(null),(null);`)
tk.MustExec(`insert into t values(1.1),(1.1);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "1.1"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for DECIMAL type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a decimal(5,1));`)
tk.MustExec(`insert into t values(null),(null),(null);`)
tk.MustExec(`insert into t values(1.1),(2.2),(2.2);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "1.1", "2.2"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for DATETIME type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a datetime);`)
tk.MustExec(`insert into t values(null);`)
tk.MustExec(`insert into t values("2019-03-20 21:50:00"),("2019-03-20 21:50:01"), ("2019-03-20 21:50:00");`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "2019-03-20 21:50:00", "2019-03-20 21:50:01"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for JSON type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a json);`)
tk.MustExec(`insert into t values(null),(null),(null),(null);`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))

// check for char type
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a char);`)
tk.MustExec(`insert into t values(null),(null),(null),(null);`)
tk.MustExec(`insert into t values('a'),('b');`)
tk.MustQuery(`select /*+ stream_agg() */ distinct * from t;`).Check(testkit.Rows("<nil>", "a", "b"))
c.Assert(tk.Se.GetSessionVars().StmtCtx.WarningCount(), Equals, uint16(0))
}
33 changes: 30 additions & 3 deletions executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -713,25 +713,40 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {

func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression.Expression, inputs []*chunk.Chunk) {
chkNum := len(chkRows)
numRows := 0
inputs = make([]*chunk.Chunk, chkNum)
fts := make([]*types.FieldType, 1)
fts[0] = types.NewFieldType(mysql.TypeLonglong)
for i := 0; i < chkNum; i++ {
inputs[i] = chunk.New(fts, chkRows[i], chkRows[i])
numRows += chkRows[i]
}
var numGroups int
if numRows%sameNum == 0 {
numGroups = numRows / sameNum
} else {
numGroups = numRows/sameNum + 1
}

rand.Seed(time.Now().Unix())
nullPos := rand.Intn(numGroups)
cnt := 0
val := 0
val := rand.Int63()
for i := 0; i < chkNum; i++ {
col := inputs[i].Column(0)
col.ResizeInt64(chkRows[i], false)
i64s := col.Int64s()
for j := 0; j < chkRows[i]; j++ {
if cnt == sameNum {
val++
val = rand.Int63()
cnt = 0
nullPos--
}
if nullPos == 0 {
col.SetNull(j, true)
} else {
i64s[j] = val
}
i64s[j] = int64(val)
cnt++
}
}
Expand Down Expand Up @@ -775,6 +790,18 @@ func (s *testExecSuite) TestVecGroupChecker(c *C) {
expectedFlag: []bool{false, false},
sameNum: 1,
},
{
chunkRows: []int{2, 2},
expectedGroups: 2,
expectedFlag: []bool{false, false},
sameNum: 2,
},
{
chunkRows: []int{2, 2},
expectedGroups: 1,
expectedFlag: []bool{false, true},
sameNum: 4,
},
}

ctx := mock.NewContext()
Expand Down