Skip to content

Commit

Permalink
fixes bugs around expression precedence and LIKE (#16934)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
Signed-off-by: Manan Gupta <manan@planetscale.com>
Co-authored-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
systay and GuptaManan100 committed Oct 14, 2024
1 parent 7797b49 commit add5652
Show file tree
Hide file tree
Showing 18 changed files with 149 additions and 110 deletions.
30 changes: 30 additions & 0 deletions go/test/endtoend/vtgate/vitess_tester/expressions/expressions.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This file contains queries that test expressions in Vitess.
# We've found a number of bugs around precedences that we want to test.
CREATE TABLE t0
(
c1 BIT,
INDEX idx_c1 (c1)
);

INSERT INTO t0(c1)
VALUES ('');


SELECT *
FROM t0;

SELECT ((t0.c1 = 'a'))
FROM t0;

SELECT *
FROM t0
WHERE ((t0.c1 = 'a'));


SELECT (1 LIKE ('a' IS NULL));
SELECT (NOT (1 LIKE ('a' IS NULL)));

SELECT (~ (1 || 0)) IS NULL;

SELECT 1
WHERE (~ (1 || 0)) IS NULL;
2 changes: 1 addition & 1 deletion go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func ASTToStatementType(stmt Statement) StatementType {
// CanNormalize takes Statement and returns if the statement can be normalized.
func CanNormalize(stmt Statement) bool {
switch stmt.(type) {
case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter
case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream, *VExplainStmt: // TODO: we could merge this logic into ASTrewriter
return true
}
return false
Expand Down
4 changes: 0 additions & 4 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1468,10 +1468,6 @@ func (op BinaryExprOperator) ToString() string {
return ShiftLeftStr
case ShiftRightOp:
return ShiftRightStr
case JSONExtractOp:
return JSONExtractOpStr
case JSONUnquoteExtractOp:
return JSONUnquoteExtractOpStr
default:
return "Unknown BinaryExprOperator"
}
Expand Down
26 changes: 11 additions & 15 deletions go/vt/sqlparser/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,17 @@ const (
IsNotFalseStr = "is not false"

// BinaryExpr.Operator
BitAndStr = "&"
BitOrStr = "|"
BitXorStr = "^"
PlusStr = "+"
MinusStr = "-"
MultStr = "*"
DivStr = "/"
IntDivStr = "div"
ModStr = "%"
ShiftLeftStr = "<<"
ShiftRightStr = ">>"
JSONExtractOpStr = "->"
JSONUnquoteExtractOpStr = "->>"
BitAndStr = "&"
BitOrStr = "|"
BitXorStr = "^"
PlusStr = "+"
MinusStr = "-"
MultStr = "*"
DivStr = "/"
IntDivStr = "div"
ModStr = "%"
ShiftLeftStr = "<<"
ShiftRightStr = ">>"

// UnaryExpr.Operator
UPlusStr = "+"
Expand Down Expand Up @@ -727,8 +725,6 @@ const (
ModOp
ShiftLeftOp
ShiftRightOp
JSONExtractOp
JSONUnquoteExtractOp
)

// Constant for Enum Type - UnaryExprOperator
Expand Down
8 changes: 5 additions & 3 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1005,9 +1005,11 @@ var (
}, {
input: "select /* u~ */ 1 from t where a = ~b",
}, {
input: "select /* -> */ a.b -> 'ab' from t",
input: "select /* -> */ a.b -> 'ab' from t",
output: "select /* -> */ json_extract(a.b, 'ab') from t",
}, {
input: "select /* -> */ a.b ->> 'ab' from t",
input: "select /* -> */ a.b ->> 'ab' from t",
output: "select /* -> */ json_unquote(json_extract(a.b, 'ab')) from t",
}, {
input: "select /* empty function */ 1 from t where a = b()",
}, {
Expand Down Expand Up @@ -5772,7 +5774,7 @@ partition by range (YEAR(purchased)) subpartition by hash (TO_DAYS(purchased))
},
{
input: "create table t (id int, info JSON, INDEX zips((CAST(info->'$.field' AS unsigned ARRAY))))",
output: "create table t (\n\tid int,\n\tinfo JSON,\n\tkey zips ((cast(info -> '$.field' as unsigned array)))\n)",
output: "create table t (\n\tid int,\n\tinfo JSON,\n\tkey zips ((cast(json_extract(info, '$.field') as unsigned array)))\n)",
},
}
parser := NewTestParser()
Expand Down
5 changes: 1 addition & 4 deletions go/vt/sqlparser/precedence.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,7 @@ func precedenceFor(in Expr) Precendence {
case *BetweenExpr:
return P12
case *ComparisonExpr:
switch node.Operator {
case EqualOp, NotEqualOp, GreaterThanOp, GreaterEqualOp, LessThanOp, LessEqualOp, LikeOp, InOp, RegexpOp, NullSafeEqualOp:
return P11
}
return P11
case *IsExpr:
return P11
case *BinaryExpr:
Expand Down
2 changes: 2 additions & 0 deletions go/vt/sqlparser/precedence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ func TestParens(t *testing.T) {
{in: "(10 - 2) - 1", expected: "10 - 2 - 1"},
{in: "10 - (2 - 1)", expected: "10 - (2 - 1)"},
{in: "0 <=> (1 and 0)", expected: "0 <=> (1 and 0)"},
{in: "1 not like ('a' is null)", expected: "1 not like ('a' is null)"},
{in: ":vtg1 not like (:vtg2 is null)", expected: ":vtg1 not like (:vtg2 is null)"},
}

parser := NewTestParser()
Expand Down
4 changes: 2 additions & 2 deletions go/vt/sqlparser/sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -5513,11 +5513,11 @@ function_call_keyword
}
| column_name_or_offset JSON_EXTRACT_OP text_literal_or_arg
{
$$ = &BinaryExpr{Left: $1, Operator: JSONExtractOp, Right: $3}
$$ = &JSONExtractExpr{JSONDoc: $1, PathList: []Expr{$3}}
}
| column_name_or_offset JSON_UNQUOTE_EXTRACT_OP text_literal_or_arg
{
$$ = &BinaryExpr{Left: $1, Operator: JSONUnquoteExtractOp, Right: $3}
$$ = &JSONUnquoteExpr{JSONValue: &JSONExtractExpr{JSONDoc: $1, PathList: []Expr{$3}}}
}

column_names_opt_paren:
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/tracked_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func TestCanonicalOutput(t *testing.T) {
},
{
"create table t (id int, info JSON, INDEX zips((CAST(info->'$.field' AS unsigned array))))",
"CREATE TABLE `t` (\n\t`id` int,\n\t`info` JSON,\n\tKEY `zips` ((CAST(`info` -> '$.field' AS unsigned array)))\n)",
"CREATE TABLE `t` (\n\t`id` int,\n\t`info` JSON,\n\tKEY `zips` ((CAST(JSON_EXTRACT(`info`, '$.field') AS unsigned array)))\n)",
},
{
"select 1 from t1 into outfile 'test/t1.txt'",
Expand Down
105 changes: 61 additions & 44 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/olekukonko/tablewriter"

"vitess.io/vitess/go/mysql/collations"
Expand Down Expand Up @@ -93,7 +95,18 @@ func (s *Tracker) String() string {
return s.buf.String()
}

func TestOneCase(t *testing.T) {
query := ``
if query == "" {
t.Skip("no query to test")
}
venv := vtenv.NewTestEnv()
env := evalengine.EmptyExpressionEnv(venv)
testCompilerCase(t, query, venv, nil, env)
}

func TestCompilerReference(t *testing.T) {
// This test runs a lot of queries and compares the results of the evalengine in eval mode to the results of the compiler.
now := time.Now()
evalengine.SystemTime = func() time.Time { return now }
defer func() { evalengine.SystemTime = time.Now }()
Expand All @@ -107,52 +120,11 @@ func TestCompilerReference(t *testing.T) {

tc.Run(func(query string, row []sqltypes.Value) {
env.Row = row

stmt, err := venv.Parser().ParseExpr(query)
if err != nil {
// no need to test un-parseable queries
return
}

fields := evalengine.FieldResolver(tc.Schema)
cfg := &evalengine.Config{
ResolveColumn: fields.Column,
ResolveType: fields.Type,
Collation: collations.CollationUtf8mb4ID,
Environment: venv,
NoConstantFolding: true,
}

converted, err := evalengine.Translate(stmt, cfg)
if err != nil {
return
}

expected, evalErr := env.EvaluateAST(converted)
total++

res, vmErr := env.Evaluate(converted)
if vmErr != nil {
switch {
case evalErr == nil:
t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr)
case evalErr.Error() != vmErr.Error():
t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr)
default:
supported++
}
return
testCompilerCase(t, query, venv, tc.Schema, env)
if !t.Failed() {
supported++
}

eval := expected.String()
comp := res.String()

if eval != comp {
t.Errorf("bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp)
return
}

supported++
})

track.Add(tc.Name(), supported, total)
Expand All @@ -162,6 +134,51 @@ func TestCompilerReference(t *testing.T) {
t.Logf("\n%s", track.String())
}

func testCompilerCase(t *testing.T, query string, venv *vtenv.Environment, schema []*querypb.Field, env *evalengine.ExpressionEnv) {
stmt, err := venv.Parser().ParseExpr(query)
if err != nil {
// no need to test un-parseable queries
return
}

fields := evalengine.FieldResolver(schema)
cfg := &evalengine.Config{
ResolveColumn: fields.Column,
ResolveType: fields.Type,
Collation: collations.CollationUtf8mb4ID,
Environment: venv,
NoConstantFolding: true,
}

converted, err := evalengine.Translate(stmt, cfg)
if err != nil {
return
}

var expected evalengine.EvalResult
var evalErr error
assert.NotPanics(t, func() {
expected, evalErr = env.EvaluateAST(converted)
})
var res evalengine.EvalResult
var vmErr error
assert.NotPanics(t, func() {
res, vmErr = env.Evaluate(converted)
})
switch {
case vmErr == nil && evalErr == nil:
eval := expected.String()
comp := res.String()
assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp)
case vmErr == nil:
t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr)
case evalErr == nil:
t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr)
case evalErr.Error() != vmErr.Error():
t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr)
}
}

func TestCompilerSingle(t *testing.T) {
var testCases = []struct {
expression string
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (expr *BinaryExpr) arguments(env *ExpressionEnv) (eval, eval, error) {
}
right, err := expr.Right.eval(env)
if err != nil {
return nil, nil, err
return left, nil, err
}
return left, right, nil
}
34 changes: 16 additions & 18 deletions go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -580,13 +580,18 @@ func (l *LikeExpr) matchWildcard(left, right []byte, coll collations.ID) bool {
}
fullColl := colldata.Lookup(coll)
wc := fullColl.Wildcard(right, 0, 0, 0)
return wc.Match(left)
return wc.Match(left) == !l.Negate
}

func (l *LikeExpr) eval(env *ExpressionEnv) (eval, error) {
left, right, err := l.arguments(env)
if left == nil || right == nil || err != nil {
return nil, err
left, err := l.Left.eval(env)
if err != nil || left == nil {
return left, err
}

right, err := l.Right.eval(env)
if err != nil || right == nil {
return right, err
}

var col collations.TypedCollation
Expand All @@ -595,18 +600,9 @@ func (l *LikeExpr) eval(env *ExpressionEnv) (eval, error) {
return nil, err
}

var matched bool
switch {
case typeIsTextual(left.SQLType()) && typeIsTextual(right.SQLType()):
matched = l.matchWildcard(left.(*evalBytes).bytes, right.(*evalBytes).bytes, col.Collation)
case typeIsTextual(right.SQLType()):
matched = l.matchWildcard(left.ToRawBytes(), right.(*evalBytes).bytes, col.Collation)
case typeIsTextual(left.SQLType()):
matched = l.matchWildcard(left.(*evalBytes).bytes, right.ToRawBytes(), col.Collation)
default:
matched = l.matchWildcard(left.ToRawBytes(), right.ToRawBytes(), collations.CollationBinaryID)
}
return newEvalBool(matched == !l.Negate), nil
matched := l.matchWildcard(left.ToRawBytes(), right.ToRawBytes(), col.Collation)

return newEvalBool(matched), nil
}

func (expr *LikeExpr) compile(c *compiler) (ctype, error) {
Expand All @@ -615,12 +611,14 @@ func (expr *LikeExpr) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

skip1 := c.compileNullCheck1(lt)

rt, err := expr.Right.compile(c)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck1(rt)

if !lt.isTextual() {
c.asm.Convert_xc(2, sqltypes.VarChar, c.collation, nil)
Expand Down Expand Up @@ -672,6 +670,6 @@ func (expr *LikeExpr) compile(c *compiler) (ctype, error) {
})
}

c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | flagNullable}, nil
}
Loading

0 comments on commit add5652

Please sign in to comment.