Skip to content

Commit

Permalink
Type Cast all update expressions in verify queries (vitessio#14555)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Jun Wang <jun.wang@demonware.net>
  • Loading branch information
GuptaManan100 authored and Jun Wang committed Nov 24, 2023
1 parent be1e8f8 commit ac36409
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 112 deletions.
9 changes: 9 additions & 0 deletions go/test/endtoend/vtgate/foreignkey/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,15 @@ func TestFkQueries(t *testing.T) {
"update fk_t11 set col = id where id in (1, 5)",
},
},
{
name: "Update on child to 0 when parent has -0",
queries: []string{
"insert into fk_t15 (id, col) values (2, '-0')",
"insert /*+ SET_VAR(foreign_key_checks=0) */ into fk_t16 (id, col) values (3, '5'), (4, '-5')",
"update fk_t16 set col = col * (col - (col)) where id = 3",
"update fk_t16 set col = col * (col - (col)) where id = 4",
},
},
}

for _, testcase := range testcases {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 4 additions & 19 deletions go/vt/vtgate/engine/fk_cascade.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// FkChild contains the Child Primitive to be executed collecting the values from the Selection Primitive using the column indexes.
Expand All @@ -42,12 +41,10 @@ type FkChild struct {

// NonLiteralUpdateInfo stores the information required to process non-literal update queries.
// It stores 4 information-
// 1. ExprCol- The index of the column being updated in the select query.
// 2. CompExprCol- The index of the comparison expression in the select query to know if the row value is actually being changed or not.
// 3. UpdateExprCol- The index of the updated expression in the select query.
// 4. UpdateExprBvName- The bind variable name to store the updated expression into.
// 1. CompExprCol- The index of the comparison expression in the select query to know if the row value is actually being changed or not.
// 2. UpdateExprCol- The index of the updated expression in the select query.
// 3. UpdateExprBvName- The bind variable name to store the updated expression into.
type NonLiteralUpdateInfo struct {
ExprCol int
CompExprCol int
UpdateExprCol int
UpdateExprBvName string
Expand Down Expand Up @@ -188,19 +185,7 @@ func (fkc *FkCascade) executeNonLiteralExprFkChild(ctx context.Context, vcursor

// Next, we need to copy the updated expressions value into the bind variables map.
for _, info := range child.NonLiteralInfo {
// Type case the value to that of the column that we are updating.
// This is required for example when we receive an updated float value of -0, but
// the column being updated is a varchar column, then if we don't coerce the value of -0 to
// varchar, MySQL ends up setting it to '0' instead of '-0'.
finalVal := row[info.UpdateExprCol]
if !finalVal.IsNull() {
var err error
finalVal, err = evalengine.CoerceTo(finalVal, selectionRes.Fields[info.ExprCol].Type)
if err != nil {
return err
}
}
bindVars[info.UpdateExprBvName] = sqltypes.ValueBindVariable(finalVal)
bindVars[info.UpdateExprBvName] = sqltypes.ValueBindVariable(row[info.UpdateExprCol])
}
_, err := vcursor.ExecutePrimitive(ctx, child.Exec, bindVars, wantfields)
if err != nil {
Expand Down
84 changes: 62 additions & 22 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"slices"
"strings"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/sysvars"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -217,7 +219,7 @@ func buildFkOperator(ctx *plancontext.PlanningContext, updOp ops.Operator, updCl
return nil, err
}

return createFKVerifyOp(ctx, op, updClone, parentFks, restrictChildFks)
return createFKVerifyOp(ctx, op, updClone, parentFks, restrictChildFks, updatedTable)
}

// splitChildFks splits the child foreign keys into restrict and cascade list as restrict is handled through Verify operator and cascade is handled through Cascade operator.
Expand Down Expand Up @@ -268,7 +270,7 @@ func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator,
for _, updExpr := range ue {
// We add the expression and a comparison expression to the SELECT exprssion while storing their offsets.
var info engine.NonLiteralUpdateInfo
info, selectExprs = addNonLiteralUpdExprToSelect(ctx, updExpr, selectExprs)
info, selectExprs = addNonLiteralUpdExprToSelect(ctx, updatedTable, updExpr, selectExprs)
nonLiteralUpdateInfo = append(nonLiteralUpdateInfo, info)
}
}
Expand Down Expand Up @@ -329,25 +331,22 @@ func addColumns(ctx *plancontext.PlanningContext, columns sqlparser.Columns, exp
// For an update query having non-literal updates, we add the updated expression and a comparison expression to the select query.
// For example, for a query like `update fk_table set col = id * 100 + 1`
// We would add the expression `id * 100 + 1` and the comparison expression `col <=> id * 100 + 1` to the select query.
func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updExpr *sqlparser.UpdateExpr, exprs []sqlparser.SelectExpr) (engine.NonLiteralUpdateInfo, []sqlparser.SelectExpr) {
func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updExpr *sqlparser.UpdateExpr, exprs []sqlparser.SelectExpr) (engine.NonLiteralUpdateInfo, []sqlparser.SelectExpr) {
// Create the comparison expression.
compExpr := sqlparser.NewComparisonExpr(sqlparser.NullSafeEqualOp, updExpr.Name, updExpr.Expr, nil)
castedExpr := getCastedUpdateExpression(updatedTable, updExpr)
compExpr := sqlparser.NewComparisonExpr(sqlparser.NullSafeEqualOp, updExpr.Name, castedExpr, nil)
info := engine.NonLiteralUpdateInfo{
CompExprCol: -1,
UpdateExprCol: -1,
ExprCol: -1,
}
// Add the expressions to the select expressions. We make sure to reuse the offset if it has already been added once.
for idx, selectExpr := range exprs {
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, compExpr) {
info.CompExprCol = idx
}
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, updExpr.Expr) {
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, castedExpr) {
info.UpdateExprCol = idx
}
if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, updExpr.Name) {
info.ExprCol = idx
}
}
// If the expression doesn't exist, then we add the expression and store the offset.
if info.CompExprCol == -1 {
Expand All @@ -356,15 +355,54 @@ func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updExpr *sql
}
if info.UpdateExprCol == -1 {
info.UpdateExprCol = len(exprs)
exprs = append(exprs, aeWrap(updExpr.Expr))
}
if info.ExprCol == -1 {
info.ExprCol = len(exprs)
exprs = append(exprs, aeWrap(updExpr.Name))
exprs = append(exprs, aeWrap(castedExpr))
}
return info, exprs
}

func getCastedUpdateExpression(updatedTable *vindexes.Table, updExpr *sqlparser.UpdateExpr) sqlparser.Expr {
castTypeStr := getCastTypeForColumn(updatedTable, updExpr)
if castTypeStr == "" {
return updExpr.Expr
}
return &sqlparser.CastExpr{
Expr: updExpr.Expr,
Type: &sqlparser.ConvertType{
Type: castTypeStr,
},
}
}

func getCastTypeForColumn(updatedTable *vindexes.Table, updExpr *sqlparser.UpdateExpr) string {
var ty querypb.Type
for _, column := range updatedTable.Columns {
if updExpr.Name.Name.Equal(column.Name) {
ty = column.Type
break
}
}
switch {
case sqltypes.IsNull(ty):
return ""
case sqltypes.IsSigned(ty):
return "SIGNED"
case sqltypes.IsUnsigned(ty):
return "UNSIGNED"
case sqltypes.IsFloat(ty):
return "FLOAT"
case sqltypes.IsDecimal(ty):
return "DECIMAL"
case sqltypes.IsDateOrTime(ty):
return "DATETIME"
case sqltypes.IsBinary(ty):
return "BINARY"
case sqltypes.IsText(ty):
return "CHAR"
default:
return ""
}
}

// createFkChildForUpdate creates the update query operator for the child table based on the foreign key constraints.
func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, selectOffsets []int, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, updatedTable *vindexes.Table) (*FkChild, error) {
// Create a ValTuple of child column names
Expand Down Expand Up @@ -484,6 +522,7 @@ func buildChildUpdOpForSetNull(
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL).
updateExprs := ctx.SemTable.GetUpdateExpressionsForFk(fk.String(updatedTable))
compExpr := nullSafeNotInComparison(ctx,
updatedTable,
updateExprs, fk, updatedTable.GetTableName(), nonLiteralUpdateInfo, false /* appendQualifier */)
if compExpr != nil {
childWhereExpr = &sqlparser.AndExpr{
Expand Down Expand Up @@ -522,6 +561,7 @@ func createFKVerifyOp(
updStmt *sqlparser.Update,
parentFks []vindexes.ParentFKInfo,
restrictChildFks []vindexes.ChildFKInfo,
updatedTable *vindexes.Table,
) (ops.Operator, error) {
if len(parentFks) == 0 && len(restrictChildFks) == 0 {
return childOp, nil
Expand All @@ -530,7 +570,7 @@ func createFKVerifyOp(
var Verify []*VerifyOp
// This validates that new values exists on the parent table.
for _, fk := range parentFks {
op, err := createFkVerifyOpForParentFKForUpdate(ctx, updStmt, fk)
op, err := createFkVerifyOpForParentFKForUpdate(ctx, updatedTable, updStmt, fk)
if err != nil {
return nil, err
}
Expand All @@ -541,7 +581,7 @@ func createFKVerifyOp(
}
// This validates that the old values don't exist on the child table.
for _, fk := range restrictChildFks {
op, err := createFkVerifyOpForChildFKForUpdate(ctx, updStmt, fk)
op, err := createFkVerifyOpForChildFKForUpdate(ctx, updatedTable, updStmt, fk)
if err != nil {
return nil, err
}
Expand All @@ -568,7 +608,7 @@ func createFKVerifyOp(
// where Parent.p1 is null and Parent.p2 is null and Child.id = 1 and Child.c2 + 1 is not null
// and Child.c2 is not null and not ((Child.c1) <=> (Child.c2 + 1))
// limit 1
func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) (ops.Operator, error) {
func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) (ops.Operator, error) {
childTblExpr := updStmt.TableExprs[0].(*sqlparser.AliasedTableExpr)
childTbl, err := childTblExpr.TableName()
if err != nil {
Expand Down Expand Up @@ -608,7 +648,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS
}
} else {
notEqualColNames = append(notEqualColNames, prefixColNames(ctx, childTbl, matchedExpr.Name))
prefixedMatchExpr := prefixColNames(ctx, childTbl, matchedExpr.Expr)
prefixedMatchExpr := prefixColNames(ctx, childTbl, getCastedUpdateExpression(updatedTable, matchedExpr))
notEqualExprs = append(notEqualExprs, prefixedMatchExpr)
joinExpr = &sqlparser.ComparisonExpr{
Operator: sqlparser.EqualOp,
Expand Down Expand Up @@ -668,7 +708,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS
// verify query:
// select 1 from Child join Parent on Parent.p1 = Child.c1 and Parent.p2 = Child.c2
// where Parent.id = 1 and ((Parent.col + 1) IS NULL OR (child.c1) NOT IN ((Parent.col + 1))) limit 1
func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, cFk vindexes.ChildFKInfo) (ops.Operator, error) {
func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updStmt *sqlparser.Update, cFk vindexes.ChildFKInfo) (ops.Operator, error) {
// ON UPDATE RESTRICT foreign keys that require validation, should only be allowed in the case where we
// are verifying all the FKs on vtgate level.
if !ctx.VerifyAllFKs {
Expand Down Expand Up @@ -710,7 +750,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt
// For example, if we are setting `update child cola = :v1 and colb = :v2`, then on the parent, the where condition would look something like this -
// `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))`
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL).
compExpr := nullSafeNotInComparison(ctx, updStmt.Exprs, cFk, parentTbl, nil /* nonLiteralUpdateInfo */, true /* appendQualifier */)
compExpr := nullSafeNotInComparison(ctx, updatedTable, updStmt.Exprs, cFk, parentTbl, nil /* nonLiteralUpdateInfo */, true /* appendQualifier */)
if compExpr != nil {
whereCond = sqlparser.AndExpressions(whereCond, compExpr)
}
Expand All @@ -735,7 +775,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt
// `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))`
// So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL)
// This expression is used in cascading SET NULLs and in verifying whether an update should be restricted.
func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, appendQualifier bool) sqlparser.Expr {
func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updatedTable *vindexes.Table, updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, appendQualifier bool) sqlparser.Expr {
var valTuple sqlparser.ValTuple
var updateValues sqlparser.ValTuple
for idx, updateExpr := range updateExprs {
Expand All @@ -744,7 +784,7 @@ func nullSafeNotInComparison(ctx *plancontext.PlanningContext, updateExprs sqlpa
if sqlparser.IsNull(updateExpr.Expr) {
return nil
}
childUpdateExpr := prefixColNames(ctx, parentTbl, updateExpr.Expr)
childUpdateExpr := prefixColNames(ctx, parentTbl, getCastedUpdateExpression(updatedTable, updateExpr))
if len(nonLiteralUpdateInfo) > 0 && nonLiteralUpdateInfo[idx].UpdateExprBvName != "" {
childUpdateExpr = sqlparser.NewArgument(nonLiteralUpdateInfo[idx].UpdateExprBvName)
}
Expand Down
Loading

0 comments on commit ac36409

Please sign in to comment.