diff --git a/go/vt/vtgate/vindexes/foreign_keys.go b/go/vt/vtgate/vindexes/foreign_keys.go index d79daf71140..e18ed27ff25 100644 --- a/go/vt/vtgate/vindexes/foreign_keys.go +++ b/go/vt/vtgate/vindexes/foreign_keys.go @@ -146,6 +146,7 @@ func (t *Table) ChildFKsNeedsHandling() (fks []ChildFKInfo) { switch fk.OnDelete { case sqlparser.Cascade, sqlparser.SetNull, sqlparser.SetDefault: fks = append(fks, fk) + continue } // sqlparser.Restrict, sqlparser.NoAction, sqlparser.DefaultAction // all the actions means the same thing i.e. Restrict @@ -159,6 +160,10 @@ func (t *Table) ChildFKsNeedsHandling() (fks []ChildFKInfo) { } func isShardScoped(pTable *Table, cTable *Table, pCols sqlparser.Columns, cCols sqlparser.Columns) bool { + if !pTable.Keyspace.Sharded { + return true + } + pPrimaryVdx := pTable.ColumnVindexes[0] cPrimaryVdx := cTable.ColumnVindexes[0] diff --git a/go/vt/vtgate/vindexes/foreign_keys_test.go b/go/vt/vtgate/vindexes/foreign_keys_test.go index b56bdf2f062..8a21e6909d0 100644 --- a/go/vt/vtgate/vindexes/foreign_keys_test.go +++ b/go/vt/vtgate/vindexes/foreign_keys_test.go @@ -154,3 +154,112 @@ func pkInfo(parentTable *Table, pCols []string, cCols []string) ParentFKInfo { ChildColumns: sqlparser.MakeColumns(cCols...), } } + +// TestChildFKs tests the ChildFKsNeedsHandling method is provides the child foreign key table whose +// rows needs to be managed by vitess. +func TestChildFKs(t *testing.T) { + col1Vindex := &ColumnVindex{ + Name: "v1", + Vindex: binVindex, + Columns: sqlparser.MakeColumns("col1"), + } + col4DiffVindex := &ColumnVindex{ + Name: "v2", + Vindex: binOnlyVindex, + Columns: sqlparser.MakeColumns("col4"), + } + + unshardedTbl := &Table{ + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: uks2, + } + shardedSingleColTbl := &Table{ + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: sks, + ColumnVindexes: []*ColumnVindex{col1Vindex}, + } + shardedSingleColTblWithDiffVindex := &Table{ + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: sks, + ColumnVindexes: []*ColumnVindex{col4DiffVindex}, + } + + tests := []struct { + name string + table *Table + expChildTbls []string + }{{ + name: "No Parent FKs", + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: sks, + }, + expChildTbls: []string{}, + }, { + name: "restrict unsharded", + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: uks2, + ChildForeignKeys: []ChildFKInfo{ckInfo(unshardedTbl, []string{"col4"}, []string{"col1"}, sqlparser.Restrict)}, + }, + expChildTbls: []string{}, + }, { + name: "restrict shard scoped", + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: sks, + ChildForeignKeys: []ChildFKInfo{ckInfo(shardedSingleColTbl, []string{"col1"}, []string{"col1"}, sqlparser.Restrict)}, + }, + expChildTbls: []string{}, + }, { + name: "restrict Keyspaces don't match", + table: &Table{ + Keyspace: uks, + ChildForeignKeys: []ChildFKInfo{ckInfo(shardedSingleColTbl, []string{"col1"}, []string{"col1"}, sqlparser.Restrict)}, + }, + expChildTbls: []string{"t1"}, + }, { + name: "restrict cross shard", + table: &Table{ + Keyspace: sks, + ColumnVindexes: []*ColumnVindex{col1Vindex}, + ChildForeignKeys: []ChildFKInfo{ckInfo(shardedSingleColTblWithDiffVindex, []string{"col4"}, []string{"col1"}, sqlparser.Restrict)}, + }, + expChildTbls: []string{"t1"}, + }, { + name: "cascade unsharded", + table: &Table{ + Keyspace: uks2, + ColumnVindexes: []*ColumnVindex{col1Vindex}, + ChildForeignKeys: []ChildFKInfo{ckInfo(unshardedTbl, []string{"col4"}, []string{"col1"}, sqlparser.Cascade)}, + }, + expChildTbls: []string{"t1"}, + }, { + name: "cascade cross shard", + table: &Table{ + Keyspace: sks, + ColumnVindexes: []*ColumnVindex{col1Vindex}, + ChildForeignKeys: []ChildFKInfo{ckInfo(shardedSingleColTblWithDiffVindex, []string{"col4"}, []string{"col1"}, sqlparser.Cascade)}, + }, + expChildTbls: []string{"t1"}, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + childFks := tt.table.ChildFKsNeedsHandling() + var actualChildTbls []string + for _, fk := range childFks { + actualChildTbls = append(actualChildTbls, fk.Table.Name.String()) + } + require.ElementsMatch(t, tt.expChildTbls, actualChildTbls) + }) + } +} + +func ckInfo(cTable *Table, pCols []string, cCols []string, refAction sqlparser.ReferenceAction) ChildFKInfo { + return ChildFKInfo{ + Table: cTable, + ParentColumns: sqlparser.MakeColumns(pCols...), + ChildColumns: sqlparser.MakeColumns(cCols...), + OnDelete: refAction, + } +}