diff --git a/go/test/endtoend/vtgate/schema/schema_test.go b/go/test/endtoend/vtgate/schema/schema_test.go index d5ae8df0a90..04d91d8d978 100644 --- a/go/test/endtoend/vtgate/schema/schema_test.go +++ b/go/test/endtoend/vtgate/schema/schema_test.go @@ -238,6 +238,18 @@ func testApplySchemaBatch(t *testing.T) { require.NoError(t, err) checkTables(t, totalTableCount) } + { + sqls := "create table batch1(id int primary key);create table batch2(id int primary key);create table batch3(id int primary key);create table batch4(id int primary key);create table batch5(id int primary key);" + _, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ApplySchema", "--", "--ddl_strategy", "direct --allow-zero-in-date", "--sql", sqls, "--batch_size", "2", keyspaceName) + require.NoError(t, err) + checkTables(t, totalTableCount+5) + } + { + sqls := "drop table batch1; drop table batch2; drop table batch3; drop table batch4; drop table batch5" + _, err := clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput("ApplySchema", "--", "--sql", sqls, keyspaceName) + require.NoError(t, err) + checkTables(t, totalTableCount) + } } // checkTables checks the number of tables in the first two shards. diff --git a/go/vt/schemamanager/tablet_executor.go b/go/vt/schemamanager/tablet_executor.go index e9e9a66c6fb..a56a95d5034 100644 --- a/go/vt/schemamanager/tablet_executor.go +++ b/go/vt/schemamanager/tablet_executor.go @@ -441,6 +441,33 @@ func (exec *TabletExecutor) executeOnAllTablets(ctx context.Context, execResult } } +// applyAllowZeroInDate takes a SQL string which may contain one or more statements, +// and, assuming those are DDLs, adds a /*vt+ allowZeroInDate=true */ directive to all of them, +// returning the result again as one long SQL. +func applyAllowZeroInDate(sql string) (string, error) { + // sql may be a batch of multiple statements + sqls, err := sqlparser.SplitStatementToPieces(sql) + if err != nil { + return sql, err + } + var modifiedSqls []string + for _, singleSQL := range sqls { + // --allow-zero-in-date Applies to DDLs + stmt, err := sqlparser.Parse(singleSQL) + if err != nil { + return sql, err + } + if ddlStmt, ok := stmt.(sqlparser.DDLStatement); ok { + // Add comments directive to allow zero in date + const directive = `/*vt+ allowZeroInDate=true */` + ddlStmt.SetComments(ddlStmt.GetParsedComments().Prepend(directive)) + singleSQL = sqlparser.String(ddlStmt) + } + modifiedSqls = append(modifiedSqls, singleSQL) + } + return strings.Join(modifiedSqls, ";"), err +} + func (exec *TabletExecutor) executeOneTablet( ctx context.Context, tablet *topodatapb.Tablet, @@ -459,22 +486,17 @@ func (exec *TabletExecutor) executeOneTablet( } else { if exec.ddlStrategySetting != nil && exec.ddlStrategySetting.IsAllowZeroInDateFlag() { // --allow-zero-in-date Applies to DDLs - stmt, err := sqlparser.Parse(string(sql)) + sql, err = applyAllowZeroInDate(sql) if err != nil { errChan <- ShardWithError{Shard: tablet.Shard, Err: err.Error()} return } - if ddlStmt, ok := stmt.(sqlparser.DDLStatement); ok { - // Add comments directive to allow zero in date - const directive = `/*vt+ allowZeroInDate=true */` - ddlStmt.SetComments(ddlStmt.GetParsedComments().Prepend(directive)) - sql = sqlparser.String(ddlStmt) - } } result, err = exec.tmc.ExecuteFetchAsDba(ctx, tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{ Query: []byte(sql), MaxRows: 10, }) + } if err != nil { errChan <- ShardWithError{Shard: tablet.Shard, Err: err.Error()} diff --git a/go/vt/schemamanager/tablet_executor_test.go b/go/vt/schemamanager/tablet_executor_test.go index 15022ecc527..175e10dfb66 100644 --- a/go/vt/schemamanager/tablet_executor_test.go +++ b/go/vt/schemamanager/tablet_executor_test.go @@ -408,3 +408,38 @@ func TestAllSQLsAreCreateQueries(t *testing.T) { }) } } + +func TestApplyAllowZeroInDate(t *testing.T) { + tcases := []struct { + sql string + expect string + }{ + { + "create table t1(id int primary key); ", + "create /*vt+ allowZeroInDate=true */ table t1 (\n\tid int primary key\n)", + }, + { + "create table t1(id int primary key)", + "create /*vt+ allowZeroInDate=true */ table t1 (\n\tid int primary key\n)", + }, + { + "create table t1(id int primary key);select 1 from dual", + "create /*vt+ allowZeroInDate=true */ table t1 (\n\tid int primary key\n);select 1 from dual", + }, + { + "create table t1(id int primary key); alter table t2 add column id2 int", + "create /*vt+ allowZeroInDate=true */ table t1 (\n\tid int primary key\n);alter /*vt+ allowZeroInDate=true */ table t2 add column id2 int", + }, + { + " ; ; ;;; create table t1(id int primary key); ;; alter table t2 add column id2 int ;;", + "create /*vt+ allowZeroInDate=true */ table t1 (\n\tid int primary key\n);alter /*vt+ allowZeroInDate=true */ table t2 add column id2 int", + }, + } + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + result, err := applyAllowZeroInDate(tcase.sql) + assert.NoError(t, err) + assert.Equal(t, tcase.expect, result) + }) + } +} diff --git a/go/vt/sqlparser/ast_test.go b/go/vt/sqlparser/ast_test.go index ff48430899e..af3fc5f41cb 100644 --- a/go/vt/sqlparser/ast_test.go +++ b/go/vt/sqlparser/ast_test.go @@ -737,7 +737,22 @@ func TestSplitStatementToPieces(t *testing.T) { "`createtime` datetime NOT NULL DEFAULT NOW() COMMENT 'create time;'," + "`comment` varchar(100) NOT NULL DEFAULT '' COMMENT 'comment'," + "PRIMARY KEY (`id`))", - }} + }, { + input: "create table t1 (id int primary key); create table t2 (id int primary key);", + output: "create table t1 (id int primary key); create table t2 (id int primary key)", + }, { + input: ";;; create table t1 (id int primary key);;; ;create table t2 (id int primary key);", + output: " create table t1 (id int primary key);create table t2 (id int primary key)", + }, { + // The input doesn't have to be valid SQL statements! + input: ";create table t1 ;create table t2 (id;", + output: "create table t1 ;create table t2 (id", + }, { + // Ignore quoted semicolon + input: ";create table t1 ';';;;create table t2 (id;", + output: "create table t1 ';';create table t2 (id", + }, + } for _, tcase := range testcases { t.Run(tcase.input, func(t *testing.T) {