Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: handle signed/unsigned in the partition pruning (#15436) #17230

Merged
merged 4 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions expression/simple_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package expression

import (
"context"

"github.com/pingcap/errors"
"github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
Expand Down Expand Up @@ -41,10 +43,20 @@ type simpleRewriter struct {
// The expression string must only reference the column in table Info.
func ParseSimpleExprWithTableInfo(ctx sessionctx.Context, exprStr string, tableInfo *model.TableInfo) (Expression, error) {
exprStr = "select " + exprStr
stmts, warns, err := parser.New().Parse(exprStr, "", "")
var stmts []ast.StmtNode
var err error
var warns []error
if p, ok := ctx.(interface {
ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error)
}); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "")
} else {
stmts, warns, err = parser.New().Parse(exprStr, "", "")
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

if err != nil {
return nil, util.SyntaxError(err)
}
Expand Down Expand Up @@ -80,12 +92,13 @@ func RewriteSimpleExprWithTableInfo(ctx sessionctx.Context, tbl *model.TableInfo
func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *Schema) ([]Expression, error) {
exprStr = "select " + exprStr
stmts, warns, err := parser.New().Parse(exprStr, "", "")
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}
if err != nil {
return nil, util.SyntaxWarn(err)
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

fields := stmts[0].(*ast.SelectStmt).Fields.Fields
exprs := make([]Expression, 0, len(fields))
for _, field := range fields {
Expand All @@ -102,13 +115,23 @@ func ParseSimpleExprsWithSchema(ctx sessionctx.Context, exprStr string, schema *
// The expression string must only reference the column in the given NameSlice.
func ParseSimpleExprsWithNames(ctx sessionctx.Context, exprStr string, schema *Schema, names types.NameSlice) ([]Expression, error) {
exprStr = "select " + exprStr
stmts, warns, err := parser.New().Parse(exprStr, "", "")
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
var stmts []ast.StmtNode
var err error
var warns []error
if p, ok := ctx.(interface {
ParseSQL(context.Context, string, string, string) ([]ast.StmtNode, []error, error)
}); ok {
stmts, warns, err = p.ParseSQL(context.Background(), exprStr, "", "")
} else {
stmts, warns, err = parser.New().Parse(exprStr, "", "")
}
if err != nil {
return nil, util.SyntaxWarn(err)
}
for _, warn := range warns {
ctx.GetSessionVars().StmtCtx.AppendWarning(util.SyntaxWarn(warn))
}

fields := stmts[0].(*ast.SelectStmt).Fields.Fields
exprs := make([]Expression, 0, len(fields))
for _, field := range fields {
Expand Down
2 changes: 1 addition & 1 deletion planner/core/partition_pruning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (s *testPartitionPruningSuite) TestPruneUseBinarySearch(c *C) {
}

for i, ca := range cases {
start, end := pruneUseBinarySearch(lessThan, ca.input)
start, end := pruneUseBinarySearch(lessThan, ca.input, false)
c.Assert(ca.result.start, Equals, start, Commentf("fail = %d", i))
c.Assert(ca.result.end, Equals, end, Commentf("fail = %d", i))
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/point_get_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ func getHashPartitionColumnName(ctx sessionctx.Context, tbl *model.TableInfo) *a
return nil
}
// PartitionExpr don't need columns and names for hash partition.
partitionExpr, err := table.(partitionTable).PartitionExpr(ctx, nil, nil)
partitionExpr, err := table.(partitionTable).PartitionExpr()
if err != nil {
return nil
}
Expand Down
66 changes: 33 additions & 33 deletions planner/core/rule_partition_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@ package core
import (
"context"
"sort"
"strconv"
"strings"

"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/model"
"github.com/pingcap/parser/mysql"
Expand Down Expand Up @@ -98,7 +96,7 @@ func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, err

// partitionTable is for those tables which implement partition.
type partitionTable interface {
PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*tables.PartitionExpr, error)
PartitionExpr() (*tables.PartitionExpr, error)
}

func generateHashPartitionExpr(t table.Table, ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (expression.Expression, error) {
Expand Down Expand Up @@ -191,12 +189,25 @@ func (lt *lessThanDataInt) length() int {
return len(lt.data)
}

func (lt *lessThanDataInt) compare(ith int, v int64) int {
func compareUnsigned(v1, v2 int64) int {
switch {
case uint64(v1) > uint64(v2):
return 1
case uint64(v1) == uint64(v2):
return 0
}
return -1
}

func (lt *lessThanDataInt) compare(ith int, v int64, unsigned bool) int {
if ith == len(lt.data)-1 {
if lt.maxvalue {
return 1
}
}
if unsigned {
return compareUnsigned(lt.data[ith], v)
}
switch {
case lt.data[ith] > v:
return 1
Expand Down Expand Up @@ -334,37 +345,24 @@ func (s *partitionProcessor) pruneRangePartition(ds *DataSource, pi *model.Parti
result := fullRange(len(pi.Definitions))
// Extract the partition column, if the column is not null, it's possible to prune.
if col != nil {
// TODO: Store LessThanData in the partitionExpr, avoid allocating here.
lessThan, err := makeLessThanData(pi)
partExpr, err := ds.table.(partitionTable).PartitionExpr()
if err != nil {
return nil, err
}
pruner := rangePruner{lessThan, col, fn}
pruner := rangePruner{
lessThan: lessThanDataInt{
data: partExpr.ForRangePruning.LessThan,
maxvalue: partExpr.ForRangePruning.MaxValue,
},
col: col,
partFn: fn,
}
result = partitionRangeForCNFExpr(ds.ctx, ds.allConds, &pruner, result)
}

return s.makeUnionAllChildren(ds, pi, result)
}

// makeLessThanData extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...'
func makeLessThanData(pi *model.PartitionInfo) (lessThanDataInt, error) {
var maxValue bool
lessThan := make([]int64, len(pi.Definitions))
for i := 0; i < len(pi.Definitions); i++ {
if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") {
// Use a bool flag instead of math.MaxInt64 to avoid the corner cases.
maxValue = true
} else {
var err error
lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64)
if err != nil {
return lessThanDataInt{}, errors.WithStack(err)
}
}
}
return lessThanDataInt{lessThan, maxValue}, nil
}

// makePartitionByFnCol extracts the column and function information in 'partition by ... fn(col)'.
func makePartitionByFnCol(sctx sessionctx.Context, columns []*expression.Column, names types.NameSlice, partitionExpr string) (*expression.Column, *expression.ScalarFunction, error) {
schema := expression.NewSchema(columns...)
Expand Down Expand Up @@ -456,7 +454,9 @@ func (p *rangePruner) partitionRangeForExpr(sctx sessionctx.Context, expr expres
if !ok {
return 0, 0, false
}
start, end := pruneUseBinarySearch(p.lessThan, dataForPrune)

unsigned := mysql.HasUnsignedFlag(p.col.RetType.Flag)
start, end := pruneUseBinarySearch(p.lessThan, dataForPrune, unsigned)
return start, end, true
}

Expand Down Expand Up @@ -598,44 +598,44 @@ func relaxOP(op string) string {
return op
}

func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune) (start int, end int) {
func pruneUseBinarySearch(lessThan lessThanDataInt, data dataForPrune, unsigned bool) (start int, end int) {
length := lessThan.length()
switch data.op {
case ast.EQ:
// col = 66, lessThan = [4 7 11 14 17] => [5, 6)
// col = 14, lessThan = [4 7 11 14 17] => [4, 5)
// col = 10, lessThan = [4 7 11 14 17] => [2, 3)
// col = 3, lessThan = [4 7 11 14 17] => [0, 1)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 })
start, end = pos, pos+1
case ast.LT:
// col < 66, lessThan = [4 7 11 14 17] => [0, 5)
// col < 14, lessThan = [4 7 11 14 17] => [0, 4)
// col < 10, lessThan = [4 7 11 14 17] => [0, 3)
// col < 3, lessThan = [4 7 11 14 17] => [0, 1)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) >= 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) >= 0 })
start, end = 0, pos+1
case ast.GE:
// col >= 66, lessThan = [4 7 11 14 17] => [5, 5)
// col >= 14, lessThan = [4 7 11 14 17] => [4, 5)
// col >= 10, lessThan = [4 7 11 14 17] => [2, 5)
// col >= 3, lessThan = [4 7 11 14 17] => [0, 5)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 })
start, end = pos, length
case ast.GT:
// col > 66, lessThan = [4 7 11 14 17] => [5, 5)
// col > 14, lessThan = [4 7 11 14 17] => [4, 5)
// col > 10, lessThan = [4 7 11 14 17] => [3, 5)
// col > 3, lessThan = [4 7 11 14 17] => [1, 5)
// col > 2, lessThan = [4 7 11 14 17] => [0, 5)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c+1, unsigned) > 0 })
start, end = pos, length
case ast.LE:
// col <= 66, lessThan = [4 7 11 14 17] => [0, 6)
// col <= 14, lessThan = [4 7 11 14 17] => [0, 5)
// col <= 10, lessThan = [4 7 11 14 17] => [0, 3)
// col <= 3, lessThan = [4 7 11 14 17] => [0, 1)
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c) > 0 })
pos := sort.Search(length, func(i int) bool { return lessThan.compare(i, data.c, unsigned) > 0 })
start, end = 0, pos+1
case ast.IsNull:
start, end = 0, 1
Expand Down
75 changes: 53 additions & 22 deletions table/tables/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ package tables

import (
"bytes"
stderr "errors"
"fmt"
"sort"
"strconv"
"strings"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -113,6 +115,46 @@ type PartitionExpr struct {
OrigExpr ast.ExprNode
// Expr is the hash partition expression.
Expr expression.Expression
// Used in the range pruning process.
*ForRangePruning
}

// ForRangePruning is used for range partition pruning.
type ForRangePruning struct {
LessThan []int64
MaxValue bool
Unsigned bool
}

// dataForRangePruning extracts the less than parts from 'partition p0 less than xx ... partitoin p1 less than ...'
func dataForRangePruning(pi *model.PartitionInfo) (*ForRangePruning, error) {
var maxValue bool
var unsigned bool
lessThan := make([]int64, len(pi.Definitions))
for i := 0; i < len(pi.Definitions); i++ {
if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") {
// Use a bool flag instead of math.MaxInt64 to avoid the corner cases.
maxValue = true
} else {
var err error
lessThan[i], err = strconv.ParseInt(pi.Definitions[i].LessThan[0], 10, 64)
var numErr *strconv.NumError
if stderr.As(err, &numErr) && numErr.Err == strconv.ErrRange {
var tmp uint64
tmp, err = strconv.ParseUint(pi.Definitions[i].LessThan[0], 10, 64)
lessThan[i] = int64(tmp)
unsigned = true
}
if err != nil {
return nil, errors.WithStack(err)
}
}
}
return &ForRangePruning{
LessThan: lessThan,
MaxValue: maxValue,
Unsigned: unsigned,
}, nil
}

// rangePartitionString returns the partition string for a range typed partition.
Expand All @@ -134,13 +176,11 @@ func rangePartitionString(pi *model.PartitionInfo) string {
func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) {
// The caller should assure partition info is not nil.
partitionPruneExprs := make([]expression.Expression, 0, len(pi.Definitions))
locateExprs := make([]expression.Expression, 0, len(pi.Definitions))
var buf bytes.Buffer
schema := expression.NewSchema(columns...)
partStr := rangePartitionString(pi)
for i := 0; i < len(pi.Definitions); i++ {

if strings.EqualFold(pi.Definitions[i].LessThan[0], "MAXVALUE") {
// Expr less than maxvalue is always true.
fmt.Fprintf(&buf, "true")
Expand All @@ -155,28 +195,19 @@ func generateRangePartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
return nil, errors.Trace(err)
}
locateExprs = append(locateExprs, exprs[0])

if i > 0 {
fmt.Fprintf(&buf, " and ((%s) >= (%s))", partStr, pi.Definitions[i-1].LessThan[0])
} else {
// NULL will locate in the first partition, so its expression is (expr < value or expr is null).
fmt.Fprintf(&buf, " or ((%s) is null)", partStr)
}

exprs, err = expression.ParseSimpleExprsWithNames(ctx, buf.String(), schema, names)
buf.Reset()
}
ret := &PartitionExpr{
UpperBounds: locateExprs,
}
if len(pi.Columns) == 0 {
tmp, err := dataForRangePruning(pi)
if err != nil {
// If it got an error here, ddl may hang forever, so this error log is important.
logutil.BgLogger().Error("wrong table partition expression", zap.String("expression", buf.String()), zap.Error(err))
return nil, errors.Trace(err)
}
// Get a hash code in advance to prevent data race afterwards.
exprs[0].HashCode(ctx.GetSessionVars().StmtCtx)
partitionPruneExprs = append(partitionPruneExprs, exprs[0])
buf.Reset()
ret.ForRangePruning = tmp
}
return &PartitionExpr{
UpperBounds: locateExprs,
}, nil
return ret, nil
}

func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
Expand All @@ -201,13 +232,13 @@ func generateHashPartitionExpr(ctx sessionctx.Context, pi *model.PartitionInfo,
}

// PartitionExpr returns the partition expression.
func (t *partitionedTable) PartitionExpr(ctx sessionctx.Context, columns []*expression.Column, names types.NameSlice) (*PartitionExpr, error) {
func (t *partitionedTable) PartitionExpr() (*PartitionExpr, error) {
pi := t.meta.GetPartitionInfo()
switch pi.Type {
case model.PartitionTypeHash:
return t.partitionExpr, nil
case model.PartitionTypeRange:
return generateRangePartitionExpr(ctx, pi, columns, names)
return t.partitionExpr, nil
}
panic("cannot reach here")
}
Expand Down
Loading