Skip to content

Commit

Permalink
expression: fix case-sensitive problem for function INSTR and LOCATE (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
sre-bot authored May 13, 2020
1 parent 77d8f23 commit 32db22c
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 33 deletions.
44 changes: 26 additions & 18 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -1399,16 +1399,16 @@ func (c *locateFunctionClass) getFunction(ctx sessionctx.Context, args []Express
return nil, err
}
var sig builtinFunc
// Loacte is multibyte safe, and is case-sensitive only if at least one argument is a binary string.
hasBianryInput := types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[1].GetType())
// Locate is multibyte safe.
useBinary := bf.collation == charset.CollationBin
switch {
case hasStartPos && hasBianryInput:
case hasStartPos && useBinary:
sig = &builtinLocate3ArgsSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Locate3Args)
case hasStartPos:
sig = &builtinLocate3ArgsUTF8Sig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Locate3ArgsUTF8)
case hasBianryInput:
case useBinary:
sig = &builtinLocate2ArgsSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Locate2Args)
default:
Expand Down Expand Up @@ -1460,7 +1460,7 @@ func (b *builtinLocate2ArgsUTF8Sig) Clone() builtinFunc {
return newSig
}

// evalInt evals LOCATE(substr,str), non case-sensitive.
// evalInt evals LOCATE(substr,str).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
subStr, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand All @@ -1474,10 +1474,13 @@ func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
if int64(len([]rune(subStr))) == 0 {
return 1, false, nil
}
slice := string([]rune(strings.ToLower(str)))
ret, idx := 0, strings.Index(slice, strings.ToLower(subStr))
if collate.IsCICollation(b.collation) {
str = strings.ToLower(str)
subStr = strings.ToLower(subStr)
}
ret, idx := 0, strings.Index(str, subStr)
if idx != -1 {
ret = utf8.RuneCountInString(slice[:idx]) + 1
ret = utf8.RuneCountInString(str[:idx]) + 1
}
return int64(ret), false, nil
}
Expand Down Expand Up @@ -1533,7 +1536,7 @@ func (b *builtinLocate3ArgsUTF8Sig) Clone() builtinFunc {
return newSig
}

// evalInt evals LOCATE(substr,str,pos), non case-sensitive.
// evalInt evals LOCATE(substr,str,pos).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
subStr, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand All @@ -1544,20 +1547,24 @@ func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
if isNull || err != nil {
return 0, isNull, err
}
if collate.IsCICollation(b.collation) {
subStr = strings.ToLower(subStr)
str = strings.ToLower(str)
}
pos, isNull, err := b.args[2].EvalInt(b.ctx, row)
// Transfer the argument which starts from 1 to real index which starts from 0.
pos--
if isNull || err != nil {
return 0, isNull, err
}
subStrLen := len([]rune(subStr))
if pos < 0 || pos > int64(len([]rune(strings.ToLower(str)))-subStrLen) {
if pos < 0 || pos > int64(len([]rune(str))-subStrLen) {
return 0, false, nil
} else if subStrLen == 0 {
return pos + 1, false, nil
}
slice := string([]rune(strings.ToLower(str))[pos:])
idx := strings.Index(slice, strings.ToLower(subStr))
slice := string([]rune(str)[pos:])
idx := strings.Index(slice, subStr)
if idx != -1 {
return pos + int64(utf8.RuneCountInString(slice[:idx])) + 1, false, nil
}
Expand Down Expand Up @@ -3722,7 +3729,7 @@ func (c *instrFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
return nil, err
}
bf.tp.Flen = 11
if types.IsBinaryStr(bf.args[0].GetType()) || types.IsBinaryStr(bf.args[1].GetType()) {
if bf.collation == charset.CollationBin {
sig := &builtinInstrSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Instr)
return sig, nil
Expand All @@ -3748,20 +3755,21 @@ func (b *builtinInstrSig) Clone() builtinFunc {
return newSig
}

// evalInt evals INSTR(str,substr), case insensitive
// evalInt evals INSTR(str,substr).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_instr
func (b *builtinInstrUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
str, IsNull, err := b.args[0].EvalString(b.ctx, row)
if IsNull || err != nil {
return 0, true, err
}
str = strings.ToLower(str)

substr, IsNull, err := b.args[1].EvalString(b.ctx, row)
if IsNull || err != nil {
return 0, true, err
}
substr = strings.ToLower(substr)
if collate.IsCICollation(b.collation) {
str = strings.ToLower(str)
substr = strings.ToLower(substr)
}

idx := strings.Index(str, substr)
if idx == -1 {
Expand All @@ -3770,7 +3778,7 @@ func (b *builtinInstrUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
return int64(utf8.RuneCountInString(str[:idx]) + 1), false, nil
}

// evalInt evals INSTR(str,substr), case sensitive
// evalInt evals INSTR(str,substr), case sensitive.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_instr
func (b *builtinInstrSig) evalInt(row chunk.Row) (int64, bool, error) {
str, IsNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down
12 changes: 6 additions & 6 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ func (s *testEvaluatorSuite) TestLocate(c *C) {
{[]interface{}{"好世", "你好世界"}, 2},
{[]interface{}{"界面", "你好世界"}, 0},
{[]interface{}{"b", "中a英b文"}, 4},
{[]interface{}{"BaR", "foobArbar"}, 4},
{[]interface{}{"bAr", "foobArbar"}, 4},
{[]interface{}{nil, "foobar"}, nil},
{[]interface{}{"bar", nil}, nil},
{[]interface{}{"bar", "foobarbar", 5}, 7},
Expand All @@ -957,7 +957,7 @@ func (s *testEvaluatorSuite) TestLocate(c *C) {
{[]interface{}{"A", "大A写的A", 1}, 2},
{[]interface{}{"A", "大A写的A", 2}, 2},
{[]interface{}{"A", "大A写的A", 3}, 5},
{[]interface{}{"bAr", "foobarBaR", 5}, 7},
{[]interface{}{"BaR", "foobarBaR", 5}, 7},
{[]interface{}{nil, nil}, nil},
{[]interface{}{"", nil}, nil},
{[]interface{}{nil, ""}, nil},
Expand Down Expand Up @@ -1624,11 +1624,11 @@ func (s *testEvaluatorSuite) TestInstr(c *C) {
{[]interface{}{"中文美好", "世界"}, 0},
{[]interface{}{"中文abc", "a"}, 3},

{[]interface{}{"live LONG and prosper", "long"}, 6},
{[]interface{}{"live long and prosper", "long"}, 6},

{[]interface{}{"not BINARY string", "binary"}, 5},
{[]interface{}{"UPPER case", "upper"}, 1},
{[]interface{}{"UPPER case", "CASE"}, 7},
{[]interface{}{"not binary string", "binary"}, 5},
{[]interface{}{"upper case", "upper"}, 1},
{[]interface{}{"UPPER CASE", "CASE"}, 7},
{[]interface{}{"中文abc", "abc"}, 3},

{[]interface{}{"foobar", nil}, nil},
Expand Down
33 changes: 24 additions & 9 deletions expression/builtin_string_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
"golang.org/x/text/transform"
Expand Down Expand Up @@ -366,7 +367,7 @@ func (b *builtinLocate3ArgsUTF8Sig) vectorized() bool {
return true
}

// vecEvalInt evals LOCATE(substr,str,pos), non case-sensitive.
// vecEvalInt evals LOCATE(substr,str,pos).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate3ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
Expand All @@ -393,6 +394,7 @@ func (b *builtinLocate3ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk

result.MergeNulls(buf, buf1)
i64s := result.Int64s()
ci := collate.IsCICollation(b.collation)
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
Expand All @@ -413,8 +415,10 @@ func (b *builtinLocate3ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk
continue
}
slice := string([]rune(str)[pos:])
subStr = strings.ToLower(subStr)
slice = strings.ToLower(slice)
if ci {
subStr = strings.ToLower(subStr)
slice = strings.ToLower(slice)
}
idx := strings.Index(slice, subStr)
if idx != -1 {
i64s[i] = pos + int64(utf8.RuneCountInString(slice[:idx])) + 1
Expand Down Expand Up @@ -1631,12 +1635,20 @@ func (b *builtinInstrUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum
result.ResizeInt64(n, false)
result.MergeNulls(str, substr)
res := result.Int64s()
ci := collate.IsCICollation(b.collation)
var strI string
var substrI string
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
strI := strings.ToLower(str.GetString(i))
substrI := strings.ToLower(substr.GetString(i))
if ci {
strI = strings.ToLower(str.GetString(i))
substrI = strings.ToLower(substr.GetString(i))
} else {
strI = str.GetString(i)
substrI = substr.GetString(i)
}
idx := strings.Index(strI, substrI)
if idx == -1 {
res[i] = 0
Expand Down Expand Up @@ -2126,7 +2138,7 @@ func (b *builtinLocate2ArgsUTF8Sig) vectorized() bool {
return true
}

// vecEvalInt evals LOCATE(substr,str), non case-sensitive.
// vecEvalInt evals LOCATE(substr,str).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate2ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
Expand All @@ -2150,6 +2162,7 @@ func (b *builtinLocate2ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk
result.ResizeInt64(n, false)
result.MergeNulls(buf, buf1)
i64s := result.Int64s()
ci := collate.IsCICollation(b.collation)
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
Expand All @@ -2161,9 +2174,11 @@ func (b *builtinLocate2ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk
i64s[i] = 1
continue
}
slice := string([]rune(str))
slice = strings.ToLower(slice)
subStr = strings.ToLower(subStr)
slice := str
if ci {
slice = strings.ToLower(slice)
subStr = strings.ToLower(subStr)
}
idx := strings.Index(slice, subStr)
if idx != -1 {
i64s[i] = int64(utf8.RuneCountInString(slice[:idx])) + 1
Expand Down
45 changes: 45 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6108,6 +6108,51 @@ func (s *testIntegrationSerialSuite) TestCollateStringFunction(c *C) {
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char)")
tk.MustGetErrMsg("select * from t t1 join t t2 on t1.a collate utf8mb4_bin = t2.a collate utf8mb4_general_ci;", "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_general_ci,EXPLICIT) for operation 'eq'")

tk.MustExec("DROP TABLE IF EXISTS t1;")
tk.MustExec("CREATE TABLE t1 ( a int, p1 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_bin,p2 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_general_ci , p3 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin,p4 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci ,n1 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_bin,n2 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_general_ci , n3 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin,n4 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci );")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values(1,' 0aA1!测试テストמבחן ',' 0aA1!测试テストמבחן ',' 0aA1!测试テストמבחן ',' 0aA1!测试テストמבחן ',' 0Aa1!测试テストמבחן ',' 0Aa1!测试テストמבחן ',' 0Aa1!测试テストמבחן ',' 0Aa1!测试テストמבחן ');")

tk.MustQuery("select INSTR(p1,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p1,n2) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p1,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p1,n4) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p2,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p2,n2) from t1;").Check(testkit.Rows("1"))
tk.MustQuery("select INSTR(p2,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p2,n4) from t1;").Check(testkit.Rows("1"))
tk.MustQuery("select INSTR(p3,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p3,n2) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p3,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p3,n4) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p4,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p4,n2) from t1;").Check(testkit.Rows("1"))
tk.MustQuery("select INSTR(p4,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p4,n4) from t1;").Check(testkit.Rows("1"))

tk.MustExec("truncate table t1;")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (3,'0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0Aa1!测试テストמבחן ','0Aa1!测试テストמבחן ','0Aa1!测试テストמבחן ','0Aa1!测试テストמבחן ');")

tk.MustQuery("select LOCATE(p1,n1) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p1,n2) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p1,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p1,n4) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p2,n1) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p2,n2) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p2,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p2,n4) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p3,n1) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p3,n2) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p3,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p3,n4) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p4,n1) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p4,n2) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p4,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p4,n4) from t1;").Check(testkit.Rows("0", "1", "1"))

tk.MustExec("drop table t1;")
}

func (s *testIntegrationSerialSuite) TestCollateLike(c *C) {
Expand Down
5 changes: 5 additions & 0 deletions util/collate/collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ func truncateTailingSpace(str string) string {
return str
}

// IsCICollation reports whether the named collation is case-insensitive.
// Only "utf8_general_ci" and "utf8mb4_general_ci" are recognized; the name is
// matched exactly, so any other spelling (including different letter case)
// yields false.
func IsCICollation(collate string) bool {
	switch collate {
	case "utf8_general_ci", "utf8mb4_general_ci":
		return true
	default:
		return false
	}
}

func init() {
newCollatorMap = make(map[string]Collator)
newCollatorIDMap = make(map[int]Collator)
Expand Down

0 comments on commit 32db22c

Please sign in to comment.