Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Mysqld.normalizedSchema() by formally parsing the query #13866

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3c5501e
normalizedSchema: slight refactoring and adding unit tests
shlomi-noach Aug 28, 2023
0acb7d9
introduce a test case that breaks
shlomi-noach Aug 28, 2023
2e93538
use 'sqlparser.ReplaceTableQualifiers()' to inject placeholder, then …
shlomi-noach Aug 28, 2023
f13fc87
allow changing the empty qualifier
shlomi-noach Aug 28, 2023
4e2f9a0
more test cases
shlomi-noach Aug 28, 2023
76e5cf3
fix unit tests (typo in table3 creation), adapt to allowing empty sch…
shlomi-noach Aug 28, 2023
62921bd
PreflightSchemaChange is deprecated
shlomi-noach Aug 28, 2023
2390702
check error value
shlomi-noach Aug 28, 2023
0ac3f49
Use CanonicalString
shlomi-noach Aug 29, 2023
7d30621
normalize expected queries
shlomi-noach Aug 29, 2023
d170ce9
normalize definition
shlomi-noach Aug 29, 2023
bad39ff
normalize definition
shlomi-noach Aug 29, 2023
cb8100a
handle CreateDatabase, AlterDatabase, DropDatabase statements
shlomi-noach Aug 29, 2023
6db7847
slight refactor
shlomi-noach Aug 29, 2023
01c83f3
get rid of text/template, use 'ReplaceTableQualifiers()'
shlomi-noach Aug 29, 2023
514e70a
normalize CREATE DATABASE query
shlomi-noach Aug 29, 2023
efebccb
Add tablespace name test case to TestCanonicalOutput unit test
mattlord Aug 29, 2023
9ff992f
Correct test case expectation
mattlord Aug 29, 2023
ad6c6c7
Properly treat TABLESPACE clause values as case sensitive
mattlord Aug 29, 2023
e5862ce
Merge branch 'main' into safe-table-name-replacement
shlomi-noach Aug 30, 2023
6db9027
Merge branch 'safe-table-name-replacement' of github.com:planetscale/…
shlomi-noach Aug 30, 2023
ab8e8f2
Revert CanonicalString() to String(). Adapt some of the unit tests
shlomi-noach Aug 30, 2023
e784181
adapt tests
shlomi-noach Aug 30, 2023
6e0ae75
adapting test results
shlomi-noach Aug 31, 2023
48dafe8
adapting test results
shlomi-noach Aug 31, 2023
0be16bf
Test auto_increment clause removal as well
mattlord Sep 19, 2023
ce04269
Merge remote-tracking branch 'origin/main' into safe-table-name-repla…
mattlord Sep 19, 2023
8fbeb88
Adjust usage of function added after PR branch
mattlord Sep 19, 2023
b3b3882
updated comment
shlomi-noach Sep 21, 2023
4b65366
updated comment
shlomi-noach Sep 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/test/endtoend/tabletmanager/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var (
getSchemaT1Results8030 = "CREATE TABLE `t1` (\n `id` bigint NOT NULL,\n `value` varchar(16) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb3"
getSchemaT1Results80 = "CREATE TABLE `t1` (\n `id` bigint NOT NULL,\n `value` varchar(16) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8"
getSchemaT1Results57 = "CREATE TABLE `t1` (\n `id` bigint(20) NOT NULL,\n `value` varchar(16) DEFAULT NULL,\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8"
getSchemaV1Results = fmt.Sprintf("CREATE ALGORITHM=UNDEFINED DEFINER=`%s`@`%s` SQL SECURITY DEFINER VIEW {{.DatabaseName}}.`v1` AS select {{.DatabaseName}}.`t1`.`id` AS `id`,{{.DatabaseName}}.`t1`.`value` AS `value` from {{.DatabaseName}}.`t1`", username, hostname)
getSchemaV1Results = fmt.Sprintf("create algorithm = UNDEFINED definer = %s@%s sql security DEFINER view {{.DatabaseName}}.v1 as select {{.DatabaseName}}.t1.id as id, {{.DatabaseName}}.t1.value as value from {{.DatabaseName}}.t1", username, hostname)
)

// TabletCommands tests the basic tablet commands
Expand Down
107 changes: 28 additions & 79 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
"vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
)

Expand Down Expand Up @@ -99,7 +100,7 @@ func (mysqld *Mysqld) GetSchema(ctx context.Context, dbName string, request *tab
if len(qr.Rows) == 0 {
return nil, fmt.Errorf("empty create database statement for %v", dbName)
}
sd.DatabaseSchema = strings.Replace(qr.Rows[0][1].ToString(), backtickDBName, "{{.DatabaseName}}", 1)
sd.DatabaseSchema = strings.Replace(qr.Rows[0][1].ToString(), backtickDBName, tmutils.DatabaseNamePlaceholder, 1)

tds, err := mysqld.collectBasicTableData(ctx, dbName, request.Tables, request.ExcludeTables, request.IncludeViews)
if err != nil {
Expand Down Expand Up @@ -249,6 +250,29 @@ func (mysqld *Mysqld) collectSchema(ctx context.Context, dbName, tableName, tabl
return fields, columns, schema, nil
}

// normalizedStatement normalizes a CREATE TABLE or CREATE VIEW statement as follows:
// - For CREATE TABLE, it stripts away any AUTO_INCREMENT=... clause.
// - For CREATE VIEW, it replaces the schema name with given `dbName`
func normalizedStatement(ctx context.Context, statementQuery, dbName, tableType string) (string, error) {
// Normalize & remove auto_increment because it changes on every insert
// FIXME(alainjobart) find a way to share this with
// vt/tabletserver/table_info.go:162
norm := statementQuery
norm = autoIncr.ReplaceAllLiteralString(norm, "")
if tableType == tmutils.TableView {
replaced, err := sqlparser.ReplaceTableQualifiers(norm, dbName, tmutils.DatabaseNamePlaceholder)
if err != nil {
// parsing unsuccessful
return norm, err
}
// Parsing successful
replaced = tmutils.UnqualifyDatabaseNamePlaceholder(replaced)
return replaced, nil
}

return norm, nil
}

// normalizedSchema returns a table schema with database names replaced, and auto_increment annotations removed.
func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, tableType string) (string, error) {
backtickDBName := sqlescape.EscapeID(dbName)
Expand All @@ -263,15 +287,7 @@ func (mysqld *Mysqld) normalizedSchema(ctx context.Context, dbName, tableName, t
// Normalize & remove auto_increment because it changes on every insert
// FIXME(alainjobart) find a way to share this with
// vt/tabletserver/table_info.go:162
norm := qr.Rows[0][1].ToString()
norm = autoIncr.ReplaceAllLiteralString(norm, "")
if tableType == tmutils.TableView {
// Views will have the dbname in there, replace it
// with {{.DatabaseName}}
norm = strings.Replace(norm, backtickDBName, "{{.DatabaseName}}", -1)
}

return norm, nil
return normalizedStatement(ctx, qr.Rows[0][1].ToString(), dbName, tableType)
}

// ResolveTables returns a list of actual tables+views matching a list
Expand Down Expand Up @@ -410,76 +426,9 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t
return colMap, err
}

// PreflightSchemaChange checks the schema changes in "changes" by applying them
// to an intermediate database that has the same schema as the target database.
// PreflightSchemaChange is deprecated
func (mysqld *Mysqld) PreflightSchemaChange(ctx context.Context, dbName string, changes []string) ([]*tabletmanagerdatapb.SchemaChangeResult, error) {
results := make([]*tabletmanagerdatapb.SchemaChangeResult, len(changes))

// Get current schema from the real database.
req := &tabletmanagerdatapb.GetSchemaRequest{IncludeViews: true, TableSchemaOnly: true}
originalSchema, err := mysqld.GetSchema(ctx, dbName, req)
if err != nil {
return nil, err
}

// Populate temporary database with it.
initialCopySQL := "SET sql_log_bin = 0;\n"
initialCopySQL += "DROP DATABASE IF EXISTS _vt_preflight;\n"
initialCopySQL += "CREATE DATABASE _vt_preflight;\n"
initialCopySQL += "USE _vt_preflight;\n"
// We're not smart enough to create the tables in a foreign-key-compatible way,
// so we temporarily disable foreign key checks while adding the existing tables.
initialCopySQL += "SET foreign_key_checks = 0;\n"
for _, td := range originalSchema.TableDefinitions {
if td.Type == tmutils.TableBaseTable {
initialCopySQL += td.Schema + ";\n"
}
}
for _, td := range originalSchema.TableDefinitions {
if td.Type == tmutils.TableView {
// Views will have {{.DatabaseName}} in there, replace
// it with _vt_preflight
s := strings.Replace(td.Schema, "{{.DatabaseName}}", "`_vt_preflight`", -1)
initialCopySQL += s + ";\n"
}
}
if err = mysqld.executeSchemaCommands(ctx, initialCopySQL); err != nil {
return nil, err
}

// For each change, record the schema before and after.
for i, change := range changes {
req := &tabletmanagerdatapb.GetSchemaRequest{IncludeViews: true}
beforeSchema, err := mysqld.GetSchema(ctx, "_vt_preflight", req)
if err != nil {
return nil, err
}

// apply schema change to the temporary database
sql := "SET sql_log_bin = 0;\n"
sql += "USE _vt_preflight;\n"
sql += change
if err = mysqld.executeSchemaCommands(ctx, sql); err != nil {
return nil, err
}

// get the result
afterSchema, err := mysqld.GetSchema(ctx, "_vt_preflight", req)
if err != nil {
return nil, err
}

results[i] = &tabletmanagerdatapb.SchemaChangeResult{BeforeSchema: beforeSchema, AfterSchema: afterSchema}
}

// and clean up the extra database
dropSQL := "SET sql_log_bin = 0;\n"
dropSQL += "DROP DATABASE _vt_preflight;\n"
if err = mysqld.executeSchemaCommands(ctx, dropSQL); err != nil {
return nil, err
}

return results, nil
return nil, vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "PreflightSchemaChange is deprecated")
}

// ApplySchemaChange will apply the schema change to the given database.
Expand Down
58 changes: 58 additions & 0 deletions go/vt/mysqlctl/schema_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package mysqlctl

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/fakesqldb"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/mysqlctl/tmutils"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand Down Expand Up @@ -103,3 +106,58 @@ func TestColumnList(t *testing.T) {
require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields))

}

func TestNormalizedStatement(t *testing.T) {
tcases := []struct {
statement string
db string
typ string
expect string
}{
{
statement: "create table mydb.t (id int auto_increment primary key) AUTO_INCREMENT=4",
db: "mydb",
typ: tmutils.TableBaseTable,
expect: "create table mydb.t (id int auto_increment primary key)",
},
{
statement: "create table `mydb`.t (id int primary key)",
db: "mydb",
typ: tmutils.TableBaseTable,
expect: "create table `mydb`.t (id int primary key)",
},
{
statement: "create view `mydb`.v as select * from t",
db: "mydb",
typ: tmutils.TableView,
expect: "create view {{.DatabaseName}}.v as select * from t",
},
{
statement: "create view `mydb`.v as select * from `mydb`.`t`",
db: "mydb",
typ: tmutils.TableView,
expect: "create view {{.DatabaseName}}.v as select * from {{.DatabaseName}}.t",
},
{
statement: "create view `mydb`.v as select * from `mydb`.mydb",
db: "mydb",
typ: tmutils.TableView,
expect: "create view {{.DatabaseName}}.v as select * from {{.DatabaseName}}.mydb",
},
{
statement: "create view `mydb`.v as select * from `mydb`.`mydb`",
db: "mydb",
typ: tmutils.TableView,
expect: "create view {{.DatabaseName}}.v as select * from {{.DatabaseName}}.mydb",
},
}
ctx := context.Background()
for _, tcase := range tcases {
testName := tcase.statement
t.Run(testName, func(t *testing.T) {
result, err := normalizedStatement(ctx, tcase.statement, tcase.db, tcase.typ)
assert.NoError(t, err)
assert.Equal(t, tcase.expect, result)
})
}
}
31 changes: 21 additions & 10 deletions go/vt/mysqlctl/tmutils/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (

"vitess.io/vitess/go/vt/concurrency"
"vitess.io/vitess/go/vt/schema"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"

tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
)
Expand All @@ -38,8 +40,17 @@ const (
TableBaseTable = "BASE TABLE"
// TableView indicates the table type is a view.
TableView = "VIEW"

DatabaseNamePlaceholder = "{{.DatabaseName}}"
)

func UnqualifyDatabaseNamePlaceholder(s string) string {
return strings.Replace(s, sqlescape.EscapeID(DatabaseNamePlaceholder), DatabaseNamePlaceholder, -1)
}
func QualifyDatabaseNamePlaceholder(s string) string {
return strings.Replace(s, DatabaseNamePlaceholder, sqlescape.EscapeID(DatabaseNamePlaceholder), -1)
}

// TableDefinitionGetColumn returns the index of a column inside a
// TableDefinition.
func TableDefinitionGetColumn(td *tabletmanagerdatapb.TableDefinition, name string) (index int, ok bool) {
Expand Down Expand Up @@ -206,14 +217,15 @@ func SchemaDefinitionGetTable(sd *tabletmanagerdatapb.SchemaDefinition, table st
// SchemaDefinitionToSQLStrings converts a SchemaDefinition to an array of SQL strings. The array contains all
// the SQL statements needed for creating the database, tables, and views - in that order.
// All SQL statements will have {{.DatabaseName}} in place of the actual db name.
func SchemaDefinitionToSQLStrings(sd *tabletmanagerdatapb.SchemaDefinition) []string {
func SchemaDefinitionToSQLStrings(sd *tabletmanagerdatapb.SchemaDefinition) ([]string, error) {
sqlStrings := make([]string, 0, len(sd.TableDefinitions)+1)
createViewSQL := make([]string, 0, len(sd.TableDefinitions))

// Backtick database name since keyspace names appear in the routing rules, and they might need to be escaped.
// We unescape() them first in case we have an explicitly escaped string was specified.
createDatabaseSQL := strings.Replace(sd.DatabaseSchema, "`{{.DatabaseName}}`", "{{.DatabaseName}}", -1)
createDatabaseSQL = strings.Replace(createDatabaseSQL, "{{.DatabaseName}}", sqlescape.EscapeID("{{.DatabaseName}}"), -1)
createDatabaseSQL := sd.DatabaseSchema
createDatabaseSQL = UnqualifyDatabaseNamePlaceholder(createDatabaseSQL)
createDatabaseSQL = QualifyDatabaseNamePlaceholder(createDatabaseSQL)
sqlStrings = append(sqlStrings, createDatabaseSQL)

for _, td := range sd.TableDefinitions {
Expand All @@ -223,17 +235,16 @@ func SchemaDefinitionToSQLStrings(sd *tabletmanagerdatapb.SchemaDefinition) []st
if td.Type == TableView {
createViewSQL = append(createViewSQL, td.Schema)
} else {
lines := strings.Split(td.Schema, "\n")
for i, line := range lines {
if strings.HasPrefix(line, "CREATE TABLE `") {
lines[i] = strings.Replace(line, "CREATE TABLE `", "CREATE TABLE `{{.DatabaseName}}`.`", 1)
}
replaced, err := sqlparser.ReplaceTableQualifiers(td.Schema, "", DatabaseNamePlaceholder)
if err != nil {
// parsing unsuccessful
return nil, vterrors.Wrapf(err, "parsing schema: %v", td.Schema)
}
sqlStrings = append(sqlStrings, strings.Join(lines, "\n"))
sqlStrings = append(sqlStrings, replaced)
}
}

return append(sqlStrings, createViewSQL...)
return append(sqlStrings, createViewSQL...), nil
}

// DiffSchema generates a report on what's different between two SchemaDefinitions
Expand Down
40 changes: 24 additions & 16 deletions go/vt/mysqlctl/tmutils/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,32 @@ import (

var basicTable1 = &tabletmanagerdatapb.TableDefinition{
Name: "table1",
Schema: "table schema 1",
Schema: "create table table1 (id int primary key)",
Type: TableBaseTable,
}
var basicTable2 = &tabletmanagerdatapb.TableDefinition{
Name: "table2",
Schema: "table schema 2",
Schema: "create table table2 (id int primary key)",
Type: TableBaseTable,
}

var table3 = &tabletmanagerdatapb.TableDefinition{
Name: "table2",
Name: "table3",
Schema: "CREATE TABLE `table3` (\n" +
"id bigint not null,\n" +
"id bigint not null\n" +
") Engine=InnoDB",
Type: TableBaseTable,
}

var view1 = &tabletmanagerdatapb.TableDefinition{
Name: "view1",
Schema: "view schema 1",
Schema: "create view view1 as select id from t1",
Type: TableView,
}

var view2 = &tabletmanagerdatapb.TableDefinition{
Name: "view2",
Schema: "view schema 2",
Schema: "create view view2 as select id from t2",
Type: TableView,
}

Expand All @@ -73,7 +73,11 @@ func TestToSQLStrings(t *testing.T) {
view1,
},
},
want: []string{"CREATE DATABASE `{{.DatabaseName}}`", basicTable1.Schema, view1.Schema},
want: []string{
"CREATE DATABASE `{{.DatabaseName}}`",
"create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)",
"create view view1 as select id from t1",
},
},
{
// SchemaDefinition doesn't need any tables or views
Expand All @@ -96,7 +100,11 @@ func TestToSQLStrings(t *testing.T) {
basicTable2,
},
},
want: []string{"CREATE DATABASE `{{.DatabaseName}}`", basicTable1.Schema, basicTable2.Schema},
want: []string{
"CREATE DATABASE `{{.DatabaseName}}`",
"create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)",
"create table `{{.DatabaseName}}`.table2 (\n\tid int primary key\n)",
},
},
{
// multiple tables and views should be ordered with all tables before views
Expand All @@ -110,9 +118,10 @@ func TestToSQLStrings(t *testing.T) {
},
},
want: []string{
"CREATE DATABASE `{{.DatabaseName}}`",
basicTable1.Schema, basicTable2.Schema,
view1.Schema, view2.Schema,
"CREATE DATABASE `{{.DatabaseName}}`", "create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)",
"create table `{{.DatabaseName}}`.table2 (\n\tid int primary key\n)",
"create view view1 as select id from t1",
"create view view2 as select id from t2",
},
},
{
Expand All @@ -126,16 +135,15 @@ func TestToSQLStrings(t *testing.T) {
},
want: []string{
"CREATE DATABASE `{{.DatabaseName}}`",
basicTable1.Schema,
"CREATE TABLE `{{.DatabaseName}}`.`table3` (\n" +
"id bigint not null,\n" +
") Engine=InnoDB",
"create table `{{.DatabaseName}}`.table1 (\n\tid int primary key\n)",
"create table `{{.DatabaseName}}`.table3 (\n\tid bigint not null\n) Engine InnoDB",
},
},
}

for _, tc := range testcases {
got := SchemaDefinitionToSQLStrings(tc.input)
got, err := SchemaDefinitionToSQLStrings(tc.input)
assert.NoError(t, err)
assert.Equal(t, tc.want, got)
}
}
Expand Down
Loading
Loading