Skip to content

Commit

Permalink
fix: ignore alternative which dml modifying reference table column
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Dec 20, 2024
1 parent 45192d2 commit bc4eb4b
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 20 deletions.
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/cte_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator
}

func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route {
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term)
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, seed, term)
if seedRoute == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De
}

var delOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createDeleteOpWithTarget(ctx, target, del.Ignore)
delOps = append(delOps, op)
}
Expand Down
14 changes: 8 additions & 6 deletions go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// If they can be merged, a new operator with the merged routing is returned
// If they cannot be merged, nil is returned.
func (jm *joinMerger) mergeJoinInputs(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr) *Route {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil
}
Expand Down Expand Up @@ -102,13 +102,13 @@ func mergeAnyShardRoutings(ctx *plancontext.PlanningContext, a, b *AnyShardRouti
}
}

func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
func prepareInputRoutes(ctx *plancontext.PlanningContext, lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) {
lhsRoute, rhsRoute := operatorsToRoutes(lhs, rhs)
if lhsRoute == nil || rhsRoute == nil {
return nil, nil, nil, nil, 0, 0, false
}

lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute)
lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(ctx, lhsRoute, rhsRoute)

a, b := getRoutingType(routingA), getRoutingType(routingB)
return lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace
Expand Down Expand Up @@ -159,7 +159,7 @@ func (rt routingType) String() string {

// getRoutesOrAlternates gets the Routings from each Route. If they are from different keyspaces,
// we check if this is a table with alternates in other keyspaces that we can use
func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
func getRoutesOrAlternates(ctx *plancontext.PlanningContext, lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) {
routingA := lhsRoute.Routing
routingB := rhsRoute.Routing
sameKeyspace := routingA.Keyspace() == routingB.Keyspace()
Expand All @@ -171,13 +171,15 @@ func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing,
return lhsRoute, rhsRoute, routingA, routingB, sameKeyspace
}

if refA, ok := routingA.(*AnyShardRouting); ok {
if refA, ok := routingA.(*AnyShardRouting); ok &&
!TableID(lhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altARoute := refA.AlternateInKeyspace(routingB.Keyspace()); altARoute != nil {
return altARoute, rhsRoute, altARoute.Routing, routingB, true
}
}

if refB, ok := routingB.(*AnyShardRouting); ok {
if refB, ok := routingB.(*AnyShardRouting); ok &&
!TableID(rhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) {
if altBRoute := refB.AlternateInKeyspace(routingA.Keyspace()); altBRoute != nil {
return lhsRoute, altBRoute, routingA, altBRoute.Routing, true
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ func mergeSubqueryInputs(ctx *plancontext.PlanningContext, in, out Operator, joi
return nil
}

inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(inRoute, outRoute)
inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(ctx, inRoute, outRoute)
inner, outer := getRoutingType(inRouting), getRoutingType(outRouting)

switch {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/union_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func mergeUnionInputs(
lhsExprs, rhsExprs sqlparser.SelectExprs,
distinct bool,
) (Operator, sqlparser.SelectExprs) {
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs)
lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs)
if lhsRoute == nil {
return nil, nil
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up
ueMap := prepareUpdateExpressionList(ctx, upd)

var updOps []dmlOp
for _, target := range ctx.SemTable.Targets.Constituents() {
for _, target := range ctx.SemTable.DMLTargets.Constituents() {
op := createUpdateOpWithTarget(ctx, upd, target, ueMap[target])
updOps = append(updOps, op)
}
Expand Down Expand Up @@ -308,7 +308,7 @@ func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.U
}
}

// Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns.
// Now we check if any of the foreign key columns that are being updated have dependencies on other updated columns.
// This is unsafe, and we currently don't support this in Vitess.
if err := ctx.SemTable.ErrIfFkDependentColumnUpdated(stmt.Exprs); err != nil {
panic(err)
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/planbuilder/plan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (s *planTestSuite) TestPlan() {
s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(vschema, "main", []string{"source_of_ref"}, []string{"id"})

// You will notice that some tests expect user.Id instead of user.id.
// This is because we now pre-create vindex columns in the symbol
Expand Down Expand Up @@ -305,6 +306,7 @@ func (s *planTestSuite) TestOne() {
s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"})
s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"})
s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"})
s.addPKsProvided(vschema, "main", []string{"source_of_ref"}, []string{"id"})

s.testFile("onecase.json", vw, false)
}
Expand Down
70 changes: 70 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/reference_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -771,5 +771,75 @@
"user.user_extra"
]
}
},
{
"comment": "update reference table with join with sharded table",
"query": "update main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col set sr.tt = 5 where m.user_id = 1",
"plan": {
"QueryType": "UPDATE",
"Original": "update main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col set sr.tt = 5 where m.user_id = 1",
"Instructions": {
"OperatorType": "DMLWithInput",
"TargetTabletType": "PRIMARY",
"Offset": [
"0:[0]"
],
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "R:0",
"JoinVars": {
"m_col": 0
},
"TableName": "music_rerouted_ref, source_of_ref",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select m.col from music as m where 1 != 1",
"Query": "select m.col from music as m where m.user_id = 1 lock in share mode",
"Table": "music",
"Values": [
"1"
],
"Vindex": "user_index"
},
{
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"FieldQuery": "select sr.id from source_of_ref as sr, rerouted_ref as rr where 1 != 1",
"Query": "select sr.id from source_of_ref as sr, rerouted_ref as rr where sr.col = :m_col and sr.id = rr.id lock in share mode",
"Table": "rerouted_ref, source_of_ref"
}
]
},
{
"OperatorType": "Update",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"TargetTabletType": "PRIMARY",
"Query": "update source_of_ref as sr set sr.tt = 5 where sr.id in ::dml_vals",
"Table": "source_of_ref"
}
]
},
"TablesUsed": [
"main.rerouted_ref",
"main.source_of_ref",
"user.music"
]
}
}
]
2 changes: 1 addition & 1 deletion go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func (a *analyzer) newSemTable(
Direct: a.binder.direct,
ExprTypes: a.typer.m,
Tables: a.tables.Tables,
Targets: a.binder.targets,
DMLTargets: a.binder.targets,
NotSingleRouteErr: a.notSingleRouteErr,
NotUnshardedErr: a.unshardedErr,
Warning: a.warning,
Expand Down
14 changes: 7 additions & 7 deletions go/vt/vtgate/semantics/semantic_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ type (
// It doesn't recurse inside derived tables to find the original dependencies.
Direct ExprDependencies

// Targets contains the TableSet of each table getting modified by the update/delete statement.
Targets TableSet
// DMLTargets contains the TableSet of each table getting modified by the update/delete statement.
DMLTargets TableSet

// ColumnEqualities is used for transitive closures (e.g., if a == b and b == c, then a == c).
ColumnEqualities map[columnName][]sqlparser.Expr
Expand Down Expand Up @@ -202,15 +202,15 @@ func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) {

// GetChildForeignKeysForTargets gets the child foreign keys as a list for all the target tables.
func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
fks = append(fks, st.childForeignKeysInvolved[ts]...)
}
return fks
}

// GetChildForeignKeysForTableSet gets the child foreign keys as a listfor the TableSet.
func (st *SemTable) GetChildForeignKeysForTableSet(target TableSet) (fks []vindexes.ChildFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
if target.IsSolvedBy(ts) {
fks = append(fks, st.childForeignKeysInvolved[ts]...)
}
Expand Down Expand Up @@ -238,15 +238,15 @@ func (st *SemTable) GetChildForeignKeysList() []vindexes.ChildFKInfo {

// GetParentForeignKeysForTargets gets the parent foreign keys as a list for all the target tables.
func (st *SemTable) GetParentForeignKeysForTargets() (fks []vindexes.ParentFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
fks = append(fks, st.parentForeignKeysInvolved[ts]...)
}
return fks
}

// GetParentForeignKeysForTableSet gets the parent foreign keys as a list for the TableSet.
func (st *SemTable) GetParentForeignKeysForTableSet(target TableSet) (fks []vindexes.ParentFKInfo) {
for _, ts := range st.Targets.Constituents() {
for _, ts := range st.DMLTargets.Constituents() {
if target.IsSolvedBy(ts) {
fks = append(fks, st.parentForeignKeysInvolved[ts]...)
}
Expand Down Expand Up @@ -970,7 +970,7 @@ func (st *SemTable) UpdateChildFKExpr(origUpdExpr *sqlparser.UpdateExpr, newExpr

// GetTargetTableSetForTableName returns the TableSet for the given table name from the target tables.
func (st *SemTable) GetTargetTableSetForTableName(name sqlparser.TableName) (TableSet, error) {
for _, target := range st.Targets.Constituents() {
for _, target := range st.DMLTargets.Constituents() {
tbl, err := st.Tables[target.TableOffset()].Name()
if err != nil {
return "", err
Expand Down

0 comments on commit bc4eb4b

Please sign in to comment.