diff --git a/go/vt/vtctl/workflow/framework_test.go b/go/vt/vtctl/workflow/framework_test.go index 991d2d7231d..4f464ac4324 100644 --- a/go/vt/vtctl/workflow/framework_test.go +++ b/go/vt/vtctl/workflow/framework_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "os" + "reflect" "regexp" "slices" "strings" @@ -38,6 +39,7 @@ import ( "vitess.io/vitess/go/vt/mysqlctl/tmutils" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/topotools" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" @@ -269,6 +271,7 @@ type testTMClient struct { vrQueries map[int][]*queryResult createVReplicationWorkflowRequests map[uint32]*createVReplicationWorkflowRequestResponse readVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest + applySchemaRequests map[uint32]*applySchemaRequestResponse primaryPositions map[uint32]string vdiffRequests map[uint32]*vdiffRequestResponse refreshStateErrors map[uint32]error @@ -291,6 +294,7 @@ func newTestTMClient(env *testEnv) *testTMClient { vrQueries: make(map[int][]*queryResult), createVReplicationWorkflowRequests: make(map[uint32]*createVReplicationWorkflowRequestResponse), readVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest), + applySchemaRequests: make(map[uint32]*applySchemaRequestResponse), readVReplicationWorkflowsResponses: make(map[string][]*tabletmanagerdatapb.ReadVReplicationWorkflowsResponse), primaryPositions: make(map[uint32]string), vdiffRequests: make(map[uint32]*vdiffRequestResponse), @@ -467,8 +471,30 @@ func (tmc *testTMClient) ExecuteFetchAsAllPrivs(ctx context.Context, tablet *top return nil, nil } +func (tmc *testTMClient) expectApplySchemaRequest(tabletID uint32, req *applySchemaRequestResponse) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + + if tmc.applySchemaRequests == nil { + tmc.applySchemaRequests = make(map[uint32]*applySchemaRequestResponse) + } + + tmc.applySchemaRequests[tabletID] = req +} + // Note: ONLY breaks up change.SQL into individual statements and executes it. Does NOT fully implement ApplySchema. func (tmc *testTMClient) ApplySchema(ctx context.Context, tablet *topodatapb.Tablet, change *tmutils.SchemaChange) (*tabletmanagerdatapb.SchemaChangeResult, error) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + + if expect, ok := tmc.applySchemaRequests[tablet.Alias.Uid]; ok { + if !reflect.DeepEqual(change, expect.change) { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unexpected ApplySchema request on tablet %s: got %+v, want %+v", + topoproto.TabletAliasString(tablet.Alias), change, expect.change) + } + return expect.res, expect.err + } + stmts := strings.Split(change.SQL, ";") for _, stmt := range stmts { @@ -497,6 +523,12 @@ type createVReplicationWorkflowRequestResponse struct { err error } +type applySchemaRequestResponse struct { + change *tmutils.SchemaChange + res *tabletmanagerdatapb.SchemaChangeResult + err error +} + func (tmc *testTMClient) expectVDiffRequest(tablet *topodatapb.Tablet, vrr *vdiffRequestResponse) { tmc.mu.Lock() defer tmc.mu.Unlock() diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go index 4c047a4b200..3182f137a2d 100644 --- a/go/vt/vtctl/workflow/traffic_switcher.go +++ b/go/vt/vtctl/workflow/traffic_switcher.go @@ -1527,12 +1527,14 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s if err != nil { return nil, err } - stmt, err := sqlparser.ParseAndBind(sqlCreateSequenceTable, sqltypes.StringBindVariable(sqlescape.EscapeID(tableName))) + escapedTableName, err := sqlescape.EnsureEscaped(tableName) if err != nil { - return nil, err + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s: %v", + tableName, err) } + stmt := sqlparser.BuildParsedQuery(sqlCreateSequenceTable, escapedTableName) _, err = ts.ws.tmc.ApplySchema(ctx, primary.Tablet, &tmutils.SchemaChange{ - SQL: stmt, + SQL: stmt.Query, Force: false, AllowReplication: true, SQLMode: vreplication.SQLMode, @@ -1543,6 +1545,9 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s tableName, globalKeyspace) } if bt := globalVSchema.Tables[sequenceMetadata.backingTableName]; bt == nil { + if globalVSchema.Tables == nil { + globalVSchema.Tables = make(map[string]*vschemapb.Table) + } globalVSchema.Tables[tableName] = &vschemapb.Table{ Type: vindexes.TypeSequence, } diff --git a/go/vt/vtctl/workflow/traffic_switcher_test.go b/go/vt/vtctl/workflow/traffic_switcher_test.go index 6f05a787f53..72684da8171 100644 --- a/go/vt/vtctl/workflow/traffic_switcher_test.go +++ b/go/vt/vtctl/workflow/traffic_switcher_test.go @@ -28,9 +28,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/vt/mysqlctl/tmutils" "vitess.io/vitess/go/vt/proto/vschema" + vtctldatapb "vitess.io/vitess/go/vt/proto/vtctldata" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/vtgate/vindexes" + "vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" ) @@ -74,6 +78,7 @@ func TestGetTargetSequenceMetadata(t *testing.T) { cell := "cell1" workflow := "wf1" table := "`t1`" + tableDDL := "create table t1 (id int not null auto_increment primary key, c1 varchar(10))" unescapedTable := "t1" sourceKeyspace := &testKeyspace{ KeyspaceName: "source-ks", @@ -91,12 +96,25 @@ func TestGetTargetSequenceMetadata(t *testing.T) { env := newTestEnv(t, ctx, cell, sourceKeyspace, targetKeyspace) defer env.close() + env.tmc.schema = map[string]*tabletmanagerdatapb.SchemaDefinition{ + unescapedTable: { + TableDefinitions: []*tabletmanagerdatapb.TableDefinition{ + { + Name: unescapedTable, + Schema: tableDDL, + }, + }, + }, + } + type testCase struct { - name string - sourceVSchema *vschema.Keyspace - targetVSchema *vschema.Keyspace - want map[string]*sequenceMetadata - err string + name string + sourceVSchema *vschema.Keyspace + targetVSchema *vschema.Keyspace + options *vtctldatapb.WorkflowOptions + want map[string]*sequenceMetadata + expectSourceApplySchemaRequest *applySchemaRequestResponse + err string } tests := []testCase{ { @@ -152,6 +170,65 @@ func TestGetTargetSequenceMetadata(t *testing.T) { }, }, }, + { + name: "auto_increment replaced with sequence", + sourceVSchema: &vschema.Keyspace{ + Vindexes: vindexes, + Tables: map[string]*vschema.Table{}, // Table will be created + }, + options: &vtctldatapb.WorkflowOptions{ + StripShardedAutoIncrement: vtctldatapb.ShardedAutoIncrementHandling_REPLACE, + GlobalKeyspace: sourceKeyspace.KeyspaceName, + }, + expectSourceApplySchemaRequest: &applySchemaRequestResponse{ + change: &tmutils.SchemaChange{ + SQL: sqlparser.BuildParsedQuery(sqlCreateSequenceTable, fmt.Sprintf("`%s_seq`", unescapedTable)).Query, + Force: false, + AllowReplication: true, + SQLMode: vreplication.SQLMode, + DisableForeignKeyChecks: true, + }, + res: &tabletmanagerdatapb.SchemaChangeResult{}, + }, + targetVSchema: &vschema.Keyspace{ + Vindexes: vindexes, + Tables: map[string]*vschema.Table{ + table: { + ColumnVindexes: []*vschema.ColumnVindex{ + { + Name: "xxhash", + Column: "`my-col`", + }, + }, + AutoIncrement: &vschema.AutoIncrement{ + Column: "my-col", + Sequence: fmt.Sprintf("%s_seq", unescapedTable), + }, + }, + }, + }, + want: map[string]*sequenceMetadata{ + fmt.Sprintf("%s_seq", unescapedTable): { + backingTableName: fmt.Sprintf("%s_seq", unescapedTable), + backingTableKeyspace: "source-ks", + backingTableDBName: "vt_source-ks", + usingTableName: unescapedTable, + usingTableDBName: "vt_targetks", + usingTableDefinition: &vschema.Table{ + ColumnVindexes: []*vschema.ColumnVindex{ + { + Column: "my-col", + Name: "xxhash", + }, + }, + AutoIncrement: &vschema.AutoIncrement{ + Column: "my-col", + Sequence: fmt.Sprintf("%s_seq", unescapedTable), + }, + }, + }, + }, + }, { name: "sequences with backticks", sourceVSchema: &vschema.Keyspace{ @@ -336,6 +413,9 @@ func TestGetTargetSequenceMetadata(t *testing.T) { Tablet: tablet, }, } + if tc.expectSourceApplySchemaRequest != nil { + env.tmc.expectApplySchemaRequest(tablet.Alias.Uid, tc.expectSourceApplySchemaRequest) + } } for i, shard := range targetKeyspace.ShardNames { tablet := env.tablets[targetKeyspace.KeyspaceName][startingTargetTabletUID+(i*tabletUIDStep)] @@ -354,6 +434,7 @@ func TestGetTargetSequenceMetadata(t *testing.T) { targetKeyspace: targetKeyspace.KeyspaceName, sources: sources, targets: targets, + options: tc.options, } got, err := ts.getTargetSequenceMetadata(ctx) if tc.err != "" {