Skip to content

Commit

Permalink
expression: fix the collation of functions with json arguments (pingc…
Browse files Browse the repository at this point in the history
  • Loading branch information
YangKeao authored May 11, 2024
1 parent f311d77 commit dcd1fa9
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pkg/expression/builtin_ilike.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (b *builtinIlikeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool,
var pattern collate.WildcardPattern
if b.args[1].ConstLevel() >= ConstOnlyInContext && b.args[2].ConstLevel() >= ConstOnlyInContext {
pattern, err = b.patternCache.getOrInitCache(ctx, func() (collate.WildcardPattern, error) {
ret := collate.ConvertAndGetBinCollation(b.collation).Pattern()
ret := collate.ConvertAndGetBinCollator(b.collation).Pattern()
ret.Compile(patternStr, byte(escape))
return ret, nil
})
Expand All @@ -106,7 +106,7 @@ func (b *builtinIlikeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool,
return 0, true, err
}
} else {
pattern = collate.ConvertAndGetBinCollation(b.collation).Pattern()
pattern = collate.ConvertAndGetBinCollator(b.collation).Pattern()
pattern.Compile(patternStr, byte(escape))
}
return boolToInt64(pattern.DoMatch(valStr)), false, nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_ilike_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (b *builtinIlikeSig) tryToVecMemorize(ctx EvalContext, param *funcParam, es
}

pattern, err := b.patternCache.getOrInitCache(ctx, func() (collate.WildcardPattern, error) {
pattern := collate.ConvertAndGetBinCollation(b.collation).Pattern()
pattern := collate.ConvertAndGetBinCollator(b.collation).Pattern()
pattern.Compile(param.getStringVal(0), byte(escape))
return pattern, nil
})
Expand Down Expand Up @@ -201,7 +201,7 @@ func (b *builtinIlikeSig) vecEvalInt(ctx EvalContext, input *chunk.Chunk, result

pattern, ok := b.tryToVecMemorize(ctx, params[1], escape)
if !ok {
pattern = collate.ConvertAndGetBinCollation(b.collation).Pattern()
pattern = collate.ConvertAndGetBinCollator(b.collation).Pattern()
return b.ilikeWithoutMemorization(pattern, params, rowNum, escape, result)
}

Expand Down
40 changes: 39 additions & 1 deletion pkg/expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,45 @@ func CheckAndDeriveCollationFromExprs(ctx BuildContext, funcName string, evalTyp
return nil, illegalMixCollationErr(funcName, args)
}

return ec, nil
return fixStringTypeForMaxLength(funcName, args, ec), nil
}

// fixStringTypeForMaxLength changes the type of string from `VARCHAR` to `MEDIUM BLOB` or `LONG BLOB` according to the max length of
// the argument. However, as TiDB doesn't have `MaxLength` for `FieldType`, this function handles the logic manually for different types. Now it only
// handles the `JSON` type, because in MySQL, `JSON` type has a big max length and will lead to `LONG BLOB` in many situations.
// To learn more about this case, read the discussion under https://github.com/pingcap/tidb/issues/52833
//
// TODO: also consider types other than `JSON`. And also think about when it'll become `MEDIUM BLOB`. This function only handles the collation, but
// not change the type and binary flag.
// TODO: some function will generate big values, like `repeat` and `space`. They should be handled according to the argument if it's a constant.
func fixStringTypeForMaxLength(funcName string, args []Expression, ec *ExprCollation) *ExprCollation {
// Be careful that the `args` is not all arguments of the `funcName`. You should check `deriveCollation` function to see which arguments are passed
// to the `CheckAndDeriveCollationFromExprs` function, and then passed here.
shouldChangeToBin := false

switch funcName {
case ast.Reverse, ast.Lower, ast.Upper, ast.SubstringIndex, ast.Trim, ast.Quote, ast.InsertFunc, ast.Substr, ast.Repeat, ast.Replace:
shouldChangeToBin = args[0].GetType().EvalType() == types.ETJson
case ast.Concat, ast.ConcatWS, ast.Elt, ast.MakeSet:
for _, arg := range args {
if arg.GetType().EvalType() == types.ETJson {
shouldChangeToBin = true
break
}
}
case ast.ExportSet:
if len(args) >= 2 {
shouldChangeToBin = args[0].GetType().EvalType() == types.ETJson || args[1].GetType().EvalType() == types.ETJson
}
if len(args) >= 3 {
shouldChangeToBin = shouldChangeToBin || args[2].GetType().EvalType() == types.ETJson
}
}

if shouldChangeToBin {
ec.Collation = collate.ConvertAndGetBinCollation(ec.Collation)
}
return ec
}

func safeConvert(ctx BuildContext, ec *ExprCollation, args ...Expression) bool {
Expand Down
24 changes: 15 additions & 9 deletions pkg/util/collate/collate.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,23 +334,29 @@ func IsCICollation(collate string) bool {
collate == "utf8mb4_0900_ai_ci"
}

// ConvertAndGetBinCollation converts collator to binary collator
func ConvertAndGetBinCollation(collate string) Collator {
// ConvertAndGetBinCollation converts collation to binary collation
func ConvertAndGetBinCollation(collate string) string {
switch collate {
case "utf8_general_ci":
return GetCollator("utf8_bin")
return "utf8_bin"
case "utf8_unicode_ci":
return GetCollator("utf8_bin")
return "utf8_bin"
case "utf8mb4_general_ci":
return GetCollator("utf8mb4_bin")
return "utf8mb4_bin"
case "utf8mb4_unicode_ci":
return GetCollator("utf8mb4_bin")
return "utf8mb4_bin"
case "utf8mb4_0900_ai_ci":
return GetCollator("utf8mb4_bin")
return "utf8mb4_bin"
case "gbk_chinese_ci":
return GetCollator("gbk_bin")
return "gbk_bin"
}
return GetCollator(collate)

return collate
}

// ConvertAndGetBinCollator converts collation to binary collator
func ConvertAndGetBinCollator(collate string) Collator {
return GetCollator(ConvertAndGetBinCollation(collate))
}

// IsBinCollation returns if the collation is 'xx_bin' or 'bin'.
Expand Down
65 changes: 65 additions & 0 deletions tests/integrationtest/r/expression/charset_and_collation.result
Original file line number Diff line number Diff line change
Expand Up @@ -1991,3 +1991,68 @@ LOCATE('bar' collate utf8mb4_0900_ai_ci, 'FOOBAR' collate utf8mb4_0900_ai_ci)
select 'FOOBAR' collate utf8mb4_0900_ai_ci REGEXP 'foo.*' collate utf8mb4_0900_ai_ci;
'FOOBAR' collate utf8mb4_0900_ai_ci REGEXP 'foo.*' collate utf8mb4_0900_ai_ci
1
set names utf8mb4 collate utf8mb4_0900_ai_ci;
select reverse(cast('[]' as json)) between 'W' and 'm';
reverse(cast('[]' as json)) between 'W' and 'm'
1
select lower(cast('[]' as json)) between 'W' and 'm';
lower(cast('[]' as json)) between 'W' and 'm'
1
select upper(cast('[]' as json)) between 'W' and 'm';
upper(cast('[]' as json)) between 'W' and 'm'
1
select substring_index(cast('[]' as json), '.', 1) between 'W' and 'm';
substring_index(cast('[]' as json), '.', 1) between 'W' and 'm'
1
select trim(cast('[]' as json)) between 'W' and 'm';
trim(cast('[]' as json)) between 'W' and 'm'
1
select quote(cast('[]' as json)) between "'W'" and "'m'";
quote(cast('[]' as json)) between "'W'" and "'m'"
1
select concat(cast('[]' as json), '1') between 'W' and 'm';
concat(cast('[]' as json), '1') between 'W' and 'm'
1
select concat('1', cast('[]' as json)) between '1W' and '1m';
concat('1', cast('[]' as json)) between '1W' and '1m'
1
select concat_ws(cast('[]' as json), '1', '1') between '1W' and '1m';
concat_ws(cast('[]' as json), '1', '1') between '1W' and '1m'
1
select concat_ws('1', cast('[]' as json)) between 'W' and 'm';
concat_ws('1', cast('[]' as json)) between 'W' and 'm'
1
select elt(1, cast('[]' as json), '[]') between 'W' and 'm';
elt(1, cast('[]' as json), '[]') between 'W' and 'm'
1
select elt(2, cast('[]' as json), '[]') between 'W' and 'm';
elt(2, cast('[]' as json), '[]') between 'W' and 'm'
1
select make_set(1, cast('[]' as json), '[]') between 'W' and 'm';
make_set(1, cast('[]' as json), '[]') between 'W' and 'm'
1
select make_set(2, cast('[]' as json), '[]') between 'W' and 'm';
make_set(2, cast('[]' as json), '[]') between 'W' and 'm'
1
select replace(cast('[]' as json), '[]', '[]') between 'W' and 'm';
replace(cast('[]' as json), '[]', '[]') between 'W' and 'm'
1
select replace('[]', '[]', cast('[]' as json)) between 'W' and 'm';
replace('[]', '[]', cast('[]' as json)) between 'W' and 'm'
0
select insert(cast('[]' as json), 0, 100, '[]') between 'W' and 'm';
insert(cast('[]' as json), 0, 100, '[]') between 'W' and 'm'
1
select insert('[]', 0, 100, cast('[]' as json)) between 'W' and 'm';
insert('[]', 0, 100, cast('[]' as json)) between 'W' and 'm'
0
select substr(cast('[]' as json), 1) between 'W' and 'm';
substr(cast('[]' as json), 1) between 'W' and 'm'
1
select repeat(cast('[]' as json), 10) between 'W' and 'm';
repeat(cast('[]' as json), 10) between 'W' and 'm'
1
select export_set(3,cast('[]' as json),'2','-',8) between 'W' and 'm';
export_set(3,cast('[]' as json),'2','-',8) between 'W' and 'm'
1
set names default;
25 changes: 25 additions & 0 deletions tests/integrationtest/t/expression/charset_and_collation.test
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,28 @@ select min(id) from t group by str order by str;
# TestUTF8MB40900AICIStrFunc
select LOCATE('bar' collate utf8mb4_0900_ai_ci, 'FOOBAR' collate utf8mb4_0900_ai_ci);
select 'FOOBAR' collate utf8mb4_0900_ai_ci REGEXP 'foo.*' collate utf8mb4_0900_ai_ci;

# TestCollationWithJSONArg
set names utf8mb4 collate utf8mb4_0900_ai_ci;
select reverse(cast('[]' as json)) between 'W' and 'm';
select lower(cast('[]' as json)) between 'W' and 'm';
select upper(cast('[]' as json)) between 'W' and 'm';
select substring_index(cast('[]' as json), '.', 1) between 'W' and 'm';
select trim(cast('[]' as json)) between 'W' and 'm';
select quote(cast('[]' as json)) between "'W'" and "'m'";
select concat(cast('[]' as json), '1') between 'W' and 'm';
select concat('1', cast('[]' as json)) between '1W' and '1m';
select concat_ws(cast('[]' as json), '1', '1') between '1W' and '1m';
select concat_ws('1', cast('[]' as json)) between 'W' and 'm';
select elt(1, cast('[]' as json), '[]') between 'W' and 'm';
select elt(2, cast('[]' as json), '[]') between 'W' and 'm';
select make_set(1, cast('[]' as json), '[]') between 'W' and 'm';
select make_set(2, cast('[]' as json), '[]') between 'W' and 'm';
select replace(cast('[]' as json), '[]', '[]') between 'W' and 'm';
select replace('[]', '[]', cast('[]' as json)) between 'W' and 'm';
select insert(cast('[]' as json), 0, 100, '[]') between 'W' and 'm';
select insert('[]', 0, 100, cast('[]' as json)) between 'W' and 'm';
select substr(cast('[]' as json), 1) between 'W' and 'm';
select repeat(cast('[]' as json), 10) between 'W' and 'm';
select export_set(3,cast('[]' as json),'2','-',8) between 'W' and 'm';
set names default;

0 comments on commit dcd1fa9

Please sign in to comment.