Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix case-sensitive problem for function INSTR and LOCATE (#16792) #17068

Merged
merged 5 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -1349,16 +1349,16 @@ func (c *locateFunctionClass) getFunction(ctx sessionctx.Context, args []Express
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argTps...)
var sig builtinFunc
// Loacte is multibyte safe, and is case-sensitive only if at least one argument is a binary string.
hasBianryInput := types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[1].GetType())
// Locate is multibyte safe.
useBinary := bf.collation == charset.CollationBin
switch {
case hasStartPos && hasBianryInput:
case hasStartPos && useBinary:
sig = &builtinLocate3ArgsSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Locate3Args)
case hasStartPos:
sig = &builtinLocate3ArgsUTF8Sig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Locate3ArgsUTF8)
case hasBianryInput:
case useBinary:
sig = &builtinLocate2ArgsSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Locate2Args)
default:
Expand Down Expand Up @@ -1410,7 +1410,7 @@ func (b *builtinLocate2ArgsUTF8Sig) Clone() builtinFunc {
return newSig
}

// evalInt evals LOCATE(substr,str), non case-sensitive.
// evalInt evals LOCATE(substr,str).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
subStr, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand All @@ -1424,10 +1424,13 @@ func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
if int64(len([]rune(subStr))) == 0 {
return 1, false, nil
}
slice := string([]rune(strings.ToLower(str)))
ret, idx := 0, strings.Index(slice, strings.ToLower(subStr))
if collate.IsCICollation(b.collation) {
str = strings.ToLower(str)
subStr = strings.ToLower(subStr)
}
ret, idx := 0, strings.Index(str, subStr)
if idx != -1 {
ret = utf8.RuneCountInString(slice[:idx]) + 1
ret = utf8.RuneCountInString(str[:idx]) + 1
}
return int64(ret), false, nil
}
Expand Down Expand Up @@ -1483,7 +1486,7 @@ func (b *builtinLocate3ArgsUTF8Sig) Clone() builtinFunc {
return newSig
}

// evalInt evals LOCATE(substr,str,pos), non case-sensitive.
// evalInt evals LOCATE(substr,str,pos).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
subStr, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand All @@ -1494,20 +1497,24 @@ func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
if isNull || err != nil {
return 0, isNull, err
}
if collate.IsCICollation(b.collation) {
subStr = strings.ToLower(subStr)
str = strings.ToLower(str)
}
pos, isNull, err := b.args[2].EvalInt(b.ctx, row)
// Transfer the argument which starts from 1 to real index which starts from 0.
pos--
if isNull || err != nil {
return 0, isNull, err
}
subStrLen := len([]rune(subStr))
if pos < 0 || pos > int64(len([]rune(strings.ToLower(str)))-subStrLen) {
if pos < 0 || pos > int64(len([]rune(str))-subStrLen) {
return 0, false, nil
} else if subStrLen == 0 {
return pos + 1, false, nil
}
slice := string([]rune(strings.ToLower(str))[pos:])
idx := strings.Index(slice, strings.ToLower(subStr))
slice := string([]rune(str)[pos:])
idx := strings.Index(slice, subStr)
if idx != -1 {
return pos + int64(utf8.RuneCountInString(slice[:idx])) + 1, false, nil
}
Expand Down Expand Up @@ -3588,7 +3595,7 @@ func (c *instrFunctionClass) getFunction(ctx sessionctx.Context, args []Expressi
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETString, types.ETString)
bf.tp.Flen = 11
if types.IsBinaryStr(bf.args[0].GetType()) || types.IsBinaryStr(bf.args[1].GetType()) {
if bf.collation == charset.CollationBin {
sig := &builtinInstrSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_Instr)
return sig, nil
Expand All @@ -3614,20 +3621,21 @@ func (b *builtinInstrSig) Clone() builtinFunc {
return newSig
}

// evalInt evals INSTR(str,substr), case insensitive
// evalInt evals INSTR(str,substr).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_instr
func (b *builtinInstrUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
str, IsNull, err := b.args[0].EvalString(b.ctx, row)
if IsNull || err != nil {
return 0, true, err
}
str = strings.ToLower(str)

substr, IsNull, err := b.args[1].EvalString(b.ctx, row)
if IsNull || err != nil {
return 0, true, err
}
substr = strings.ToLower(substr)
if collate.IsCICollation(b.collation) {
str = strings.ToLower(str)
substr = strings.ToLower(substr)
}

idx := strings.Index(str, substr)
if idx == -1 {
Expand All @@ -3636,7 +3644,7 @@ func (b *builtinInstrUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) {
return int64(utf8.RuneCountInString(str[:idx]) + 1), false, nil
}

// evalInt evals INSTR(str,substr), case sensitive
// evalInt evals INSTR(str,substr), case sensitive.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_instr
func (b *builtinInstrSig) evalInt(row chunk.Row) (int64, bool, error) {
str, IsNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down
12 changes: 6 additions & 6 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ func (s *testEvaluatorSuite) TestLocate(c *C) {
{[]interface{}{"好世", "你好世界"}, 2},
{[]interface{}{"界面", "你好世界"}, 0},
{[]interface{}{"b", "中a英b文"}, 4},
{[]interface{}{"BaR", "foobArbar"}, 4},
{[]interface{}{"bAr", "foobArbar"}, 4},
{[]interface{}{nil, "foobar"}, nil},
{[]interface{}{"bar", nil}, nil},
{[]interface{}{"bar", "foobarbar", 5}, 7},
Expand All @@ -957,7 +957,7 @@ func (s *testEvaluatorSuite) TestLocate(c *C) {
{[]interface{}{"A", "大A写的A", 1}, 2},
{[]interface{}{"A", "大A写的A", 2}, 2},
{[]interface{}{"A", "大A写的A", 3}, 5},
{[]interface{}{"bAr", "foobarBaR", 5}, 7},
{[]interface{}{"BaR", "foobarBaR", 5}, 7},
{[]interface{}{nil, nil}, nil},
{[]interface{}{"", nil}, nil},
{[]interface{}{nil, ""}, nil},
Expand Down Expand Up @@ -1624,11 +1624,11 @@ func (s *testEvaluatorSuite) TestInstr(c *C) {
{[]interface{}{"中文美好", "世界"}, 0},
{[]interface{}{"中文abc", "a"}, 3},

{[]interface{}{"live LONG and prosper", "long"}, 6},
{[]interface{}{"live long and prosper", "long"}, 6},

{[]interface{}{"not BINARY string", "binary"}, 5},
{[]interface{}{"UPPER case", "upper"}, 1},
{[]interface{}{"UPPER case", "CASE"}, 7},
{[]interface{}{"not binary string", "binary"}, 5},
{[]interface{}{"upper case", "upper"}, 1},
{[]interface{}{"UPPER CASE", "CASE"}, 7},
{[]interface{}{"中文abc", "abc"}, 3},

{[]interface{}{"foobar", nil}, nil},
Expand Down
33 changes: 24 additions & 9 deletions expression/builtin_string_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
"golang.org/x/text/transform"
Expand Down Expand Up @@ -366,7 +367,7 @@ func (b *builtinLocate3ArgsUTF8Sig) vectorized() bool {
return true
}

// vecEvalInt evals LOCATE(substr,str,pos), non case-sensitive.
// vecEvalInt evals LOCATE(substr,str,pos).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate3ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
Expand All @@ -393,6 +394,7 @@ func (b *builtinLocate3ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk

result.MergeNulls(buf, buf1)
i64s := result.Int64s()
ci := collate.IsCICollation(b.collation)
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
Expand All @@ -413,8 +415,10 @@ func (b *builtinLocate3ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk
continue
}
slice := string([]rune(str)[pos:])
subStr = strings.ToLower(subStr)
slice = strings.ToLower(slice)
if ci {
subStr = strings.ToLower(subStr)
slice = strings.ToLower(slice)
}
idx := strings.Index(slice, subStr)
if idx != -1 {
i64s[i] = pos + int64(utf8.RuneCountInString(slice[:idx])) + 1
Expand Down Expand Up @@ -1631,12 +1635,20 @@ func (b *builtinInstrUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum
result.ResizeInt64(n, false)
result.MergeNulls(str, substr)
res := result.Int64s()
ci := collate.IsCICollation(b.collation)
var strI string
var substrI string
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
strI := strings.ToLower(str.GetString(i))
substrI := strings.ToLower(substr.GetString(i))
if ci {
strI = strings.ToLower(str.GetString(i))
substrI = strings.ToLower(substr.GetString(i))
} else {
strI = str.GetString(i)
substrI = substr.GetString(i)
}
idx := strings.Index(strI, substrI)
if idx == -1 {
res[i] = 0
Expand Down Expand Up @@ -2126,7 +2138,7 @@ func (b *builtinLocate2ArgsUTF8Sig) vectorized() bool {
return true
}

// vecEvalInt evals LOCATE(substr,str), non case-sensitive.
// vecEvalInt evals LOCATE(substr,str).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate
func (b *builtinLocate2ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) error {
n := input.NumRows()
Expand All @@ -2150,6 +2162,7 @@ func (b *builtinLocate2ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk
result.ResizeInt64(n, false)
result.MergeNulls(buf, buf1)
i64s := result.Int64s()
ci := collate.IsCICollation(b.collation)
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
Expand All @@ -2161,9 +2174,11 @@ func (b *builtinLocate2ArgsUTF8Sig) vecEvalInt(input *chunk.Chunk, result *chunk
i64s[i] = 1
continue
}
slice := string([]rune(str))
slice = strings.ToLower(slice)
subStr = strings.ToLower(subStr)
slice := str
if ci {
slice = strings.ToLower(slice)
subStr = strings.ToLower(subStr)
}
idx := strings.Index(slice, subStr)
if idx != -1 {
i64s[i] = int64(utf8.RuneCountInString(slice[:idx])) + 1
Expand Down
45 changes: 45 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6108,6 +6108,51 @@ func (s *testIntegrationSerialSuite) TestCollateStringFunction(c *C) {
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char)")
tk.MustGetErrMsg("select * from t t1 join t t2 on t1.a collate utf8mb4_bin = t2.a collate utf8mb4_general_ci;", "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_general_ci,EXPLICIT) for operation 'eq'")

tk.MustExec("DROP TABLE IF EXISTS t1;")
tk.MustExec("CREATE TABLE t1 ( a int, p1 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_bin,p2 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_general_ci , p3 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin,p4 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci ,n1 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_bin,n2 VARCHAR(255) CHARACTER SET utf8 COLLATE utf8_general_ci , n3 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin,n4 VARCHAR(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci );")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values(1,' 0aA1!测试テストמבחן ',' 0aA1!测试テストמבחן ',' 0aA1!测试テストמבחן ',' 0aA1!测试テストמבחן ',' 0Aa1!测试テストמבחן ',' 0Aa1!测试テストמבחן ',' 0Aa1!测试テストמבחן ',' 0Aa1!测试テストמבחן ');")

tk.MustQuery("select INSTR(p1,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p1,n2) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p1,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p1,n4) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p2,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p2,n2) from t1;").Check(testkit.Rows("1"))
tk.MustQuery("select INSTR(p2,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p2,n4) from t1;").Check(testkit.Rows("1"))
tk.MustQuery("select INSTR(p3,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p3,n2) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p3,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p3,n4) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p4,n1) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p4,n2) from t1;").Check(testkit.Rows("1"))
tk.MustQuery("select INSTR(p4,n3) from t1;").Check(testkit.Rows("0"))
tk.MustQuery("select INSTR(p4,n4) from t1;").Check(testkit.Rows("1"))

tk.MustExec("truncate table t1;")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (3,'0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0aA1!测试テストמבחן','0Aa1!测试テストמבחן ','0Aa1!测试テストמבחן ','0Aa1!测试テストמבחן ','0Aa1!测试テストמבחן ');")

tk.MustQuery("select LOCATE(p1,n1) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p1,n2) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p1,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p1,n4) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p2,n1) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p2,n2) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p2,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p2,n4) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p3,n1) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p3,n2) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p3,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p3,n4) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p4,n1) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p4,n2) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select LOCATE(p4,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p4,n4) from t1;").Check(testkit.Rows("0", "1", "1"))

tk.MustExec("drop table t1;")
}

func (s *testIntegrationSerialSuite) TestCollateLike(c *C) {
Expand Down
5 changes: 5 additions & 0 deletions util/collate/collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ func truncateTailingSpace(str string) string {
return str
}

// IsCICollation returns if the collation is case-sensitive
func IsCICollation(collate string) bool {
return collate == "utf8_general_ci" || collate == "utf8mb4_general_ci"
}

func init() {
newCollatorMap = make(map[string]Collator)
newCollatorIDMap = make(map[int]Collator)
Expand Down