From 8e199796d8fb4bf750f793c8b2e5402c91493130 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Thu, 25 Feb 2021 18:40:18 -0800 Subject: [PATCH 01/16] Check constraint starter; breaks a lot of tests, but compiles --- memory/table.go | 34 ++++++++++ sql/core.go | 17 +++++ sql/parse/parse.go | 37 +++++++--- sql/plan/alter_check.go | 145 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 225 insertions(+), 8 deletions(-) create mode 100644 sql/plan/alter_check.go diff --git a/memory/table.go b/memory/table.go index 93989df1fe..f24fa88aab 100644 --- a/memory/table.go +++ b/memory/table.go @@ -37,6 +37,7 @@ type Table struct { columns []int indexes map[string]sql.Index foreignKeys []sql.ForeignKeyConstraint + checks []sql.CheckConstraint pkIndexesEnabled bool // Data storage @@ -1048,6 +1049,11 @@ func (t *Table) GetForeignKeys(_ *sql.Context) ([]sql.ForeignKeyConstraint, erro return t.foreignKeys, nil } +// GetForeignKeys implements sql.ForeignKeyTable +func (t *Table) GetChecks(_ *sql.Context) ([]sql.CheckConstraint, error) { + return t.checks, nil +} + // CreateForeignKey implements sql.ForeignKeyAlterableTable. Foreign keys are not enforced on update / delete. func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string, referencedTable string, referencedColumns []string, onUpdate, onDelete sql.ForeignKeyReferenceOption) error { for _, key := range t.foreignKeys { @@ -1079,6 +1085,34 @@ func (t *Table) DropForeignKey(ctx *sql.Context, fkName string) error { return nil } +// CreateCheck implements sql.CheckAlterableTable +func (t *Table) CreateCheckConstraint(_ *sql.Context, chName string, expr sql.Expression, enforced bool) error { + for _, key := range t.checks { + if key.Name == chName { + return fmt.Errorf("constraint %s already exists", chName) + } + } + + t.checks = append(t.checks, sql.CheckConstraint{ + Name: chName, + Expr: expr, + Enforced: enforced, + }) + + return nil +} + +// func (t *Table) DropCheck(ctx *sql.Context, chName string) error {} implements sql.CheckAlterableTable. +func (t *Table) DropCheckConstraint(ctx *sql.Context, chName string) error { + for i, key := range t.checks { + if key.Name == chName { + t.checks = append(t.checks[:i], t.checks[i+1:]...) + return nil + } + } + return nil +} + func (t *Table) createIndex(name string, columns []sql.IndexColumn, constraint sql.IndexConstraint, comment string) (sql.Index, error) { if t.indexes[name] != nil { // TODO: extract a standard error type for this diff --git a/sql/core.go b/sql/core.go index 39cc599d89..0d4e2b126c 100644 --- a/sql/core.go +++ b/sql/core.go @@ -205,6 +205,13 @@ type ForeignKeyConstraint struct { OnDelete ForeignKeyReferenceOption } +// CheckConstraint declares a constraint between the columns of two tables. +type CheckConstraint struct { + Name string + Expr Expression + Enforced bool +} + // TableWrapper is a node that wraps the real table. This is needed because // wrappers cannot implement some methods the table may implement. type TableWrapper interface { @@ -329,6 +336,16 @@ type ForeignKeyAlterableTable interface { DropForeignKey(ctx *Context, fkName string) error } +// ForeignKeyAlterableTable represents a table that supports foreign key modification operations. +type CheckAlterableTable interface { + Table + // CreateCheckConstraint creates an check constraint for this table, using the provided parameters. + // Returns an error if the constraint name already exists. + CreateCheckConstraint(ctx *Context, chName string, expr Expression, enforced bool) error + // DropCheckConstraint removes a check constraint from the database. + DropCheckConstraint(ctx *Context, chName string) error +} + // InsertableTable is a table that can process insertion of new rows. type InsertableTable interface { Table diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 06b74850a0..394b161815 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -702,24 +702,31 @@ func convertAlterTable(ctx *sql.Context, ddl *sqlparser.DDL) (sql.Node, error) { //TODO: support multiple constraints in a single ALTER statement if ddl.ConstraintAction != "" && len(ddl.TableSpec.Constraints) == 1 { table := tableNameToUnresolvedTable(ddl.Table) - parsedConstraint, err := convertConstraintDefinition(ctx, ddl.TableSpec.Constraints[0]) + parsedConstraint, err := convertConstraintDefinition(ctx, ddl.TableSpec.Constraints[0], ddl.TableSpec.Constraints[0].Name) if err != nil { return nil, err } switch strings.ToLower(ddl.ConstraintAction) { case sqlparser.AddStr: - if fkConstraint, ok := parsedConstraint.(*sql.ForeignKeyConstraint); ok { + switch c := parsedConstraint.(type) { + case *sql.ForeignKeyConstraint: return plan.NewAlterAddForeignKey( table, - plan.NewUnresolvedTable(fkConstraint.ReferencedTable, ddl.Table.Qualifier.String()), - fkConstraint), nil - } else { + plan.NewUnresolvedTable(c.ReferencedTable, ddl.Table.Qualifier.String()), + c), nil + + case *sql.CheckConstraint: + return plan.NewAlterAddCheck(table, c), nil + default: return nil, ErrUnsupportedFeature.New(sqlparser.String(ddl)) + } case sqlparser.DropStr: switch c := parsedConstraint.(type) { case *sql.ForeignKeyConstraint: return plan.NewAlterDropForeignKey(table, c), nil + case *sql.CheckConstraint: + return plan.NewAlterDropCheck(table, c), nil case namedConstraint: // For simple named constraint drops, fill in a partial foreign key constraint. This will need to be changed if // we ever support other kinds of constraints than foreign keys (e.g. CHECK) @@ -913,8 +920,13 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { } var fkDefs []*sql.ForeignKeyConstraint + constraintCnt := 0 for _, unknownConstraint := range c.TableSpec.Constraints { - parsedConstraint, err := convertConstraintDefinition(ctx, unknownConstraint) + name := unknownConstraint.Name + if name == "" { + name = fmt.Sprintf("%s_chk_%d", c.Table.Name, constraintCnt) + } + parsedConstraint, err := convertConstraintDefinition(ctx, unknownConstraint, name) if err != nil { return nil, err } @@ -997,7 +1009,7 @@ type namedConstraint struct { name string } -func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefinition) (interface{}, error) { +func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefinition, name string) (interface{}, error) { if fkConstraint, ok := cd.Details.(*sqlparser.ForeignKeyDefinition); ok { columns := make([]string, len(fkConstraint.Source)) for i, col := range fkConstraint.Source { @@ -1008,13 +1020,22 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin refColumns[i] = col.String() } return &sql.ForeignKeyConstraint{ - Name: cd.Name, + Name: name, Columns: columns, ReferencedTable: fkConstraint.ReferencedTable.Name.String(), ReferencedColumns: refColumns, OnUpdate: convertReferenceAction(fkConstraint.OnUpdate), OnDelete: convertReferenceAction(fkConstraint.OnDelete), }, nil + } else if cConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinition); ok { + c, err := exprToExpression(ctx, cConstraint.Expr) + if err != nil { + return nil, err + } + return &sql.CheckConstraint{ + Name: name, + Expr: c, + }, nil } else if len(cd.Name) > 0 && cd.Details == nil { return namedConstraint{cd.Name}, nil } diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go new file mode 100644 index 0000000000..d6508fac1b --- /dev/null +++ b/sql/plan/alter_check.go @@ -0,0 +1,145 @@ +// Copyright 2020-2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "fmt" + "gopkg.in/src-d/go-errors.v1" + + "github.com/dolthub/go-mysql-server/sql" +) + +var ( + // ErrNoCheckSupport is returned when the table does not support FOREIGN KEY operations. + ErrNoCheckSupport = errors.NewKind("the table does not support foreign key operations: %s") + // ErrCheckMissingColumns is returned when an ALTER TABLE ADD FOREIGN KEY statement does not provide any columns + ErrCheckMissingColumns = errors.NewKind("cannot create a foreign key without columns") + // ErrAddCheckDuplicateColumn is returned when an ALTER TABLE ADD FOREIGN KEY statement has the same column multiple times + ErrAddCheckDuplicateColumn = errors.NewKind("cannot have duplicates of columns in a foreign key: `%v`") +) + +type CreateCheck struct { + UnaryNode + ChDef *sql.CheckConstraint +} + +type DropCheck struct { + UnaryNode + ChDef *sql.CheckConstraint +} + +func NewAlterAddCheck(table sql.Node, chDef *sql.CheckConstraint) *CreateCheck { + return &CreateCheck{ + UnaryNode: UnaryNode{table}, + ChDef: chDef, + } +} + +func NewAlterDropCheck(table sql.Node, chDef *sql.CheckConstraint) *DropCheck { + return &DropCheck{ + UnaryNode: UnaryNode{Child: table}, + ChDef: chDef, + } +} + +func getCheckAlterable(node sql.Node) (sql.CheckAlterableTable, error) { + switch node := node.(type) { + case sql.CheckAlterableTable: + return node, nil + case *ResolvedTable: + return getCheckAlterableTable(node.Table) + default: + return nil, ErrNoCheckSupport.New(node.String()) + } +} + +func getCheckAlterableTable(t sql.Table) (sql.CheckAlterableTable, error) { + switch t := t.(type) { + case sql.CheckAlterableTable: + return t, nil + case sql.TableWrapper: + return getCheckAlterableTable(t.Underlying()) + default: + return nil, ErrNoCheckSupport.New(t.Name()) + } +} + +// Execute inserts the rows in the database. +func (p *CreateCheck) Execute(ctx *sql.Context) error { + chAlterable, err := getCheckAlterable(p.UnaryNode.Child) + if err != nil { + return err + } + return chAlterable.CreateCheckConstraint(ctx, p.ChDef.Name, p.ChDef.Expr, p.ChDef.Enforced) +} + +// Execute inserts the rows in the database. +func (p *DropCheck) Execute(ctx *sql.Context) error { + chAlterable, err := getCheckAlterable(p.UnaryNode.Child) + if err != nil { + return err + } + return chAlterable.DropCheckConstraint(ctx, p.ChDef.Name) +} + +// RowIter implements the Node interface. +func (p *DropCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + err := p.Execute(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(), nil +} + +// WithChildren implements the Node interface. +func (p *DropCheck) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewAlterDropCheck(children[0], p.ChDef), nil +} + +// WithChildren implements the Node interface. +func (p *CreateCheck) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewAlterDropCheck(children[0], p.ChDef), nil +} + +func (p *CreateCheck) Schema() sql.Schema { return nil } +func (p *DropCheck) Schema() sql.Schema { return nil } + +func (p *CreateCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + err := p.Execute(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(), nil +} + +func (p DropCheck) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("DropCheck(%s)", p.ChDef.Name) + _ = pr.WriteChildren(fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String())) + return pr.String() +} + +func (p CreateCheck) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("AddCheck(%s)", p.ChDef.Name) + _ = pr.WriteChildren(fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String())) + return pr.String() +} From a2960e65150b14f9206cadd8b178e7a8f340b9b0 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 26 Feb 2021 09:46:19 -0800 Subject: [PATCH 02/16] Fix test foreign key names --- sql/parse/parse.go | 2 +- sql/parse/parse_test.go | 10 +++++----- sql/plan/alter_check.go | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index e4c8dc63b7..d7e7c16cd3 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -924,7 +924,7 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { for _, unknownConstraint := range c.TableSpec.Constraints { name := unknownConstraint.Name if name == "" { - name = fmt.Sprintf("%s_chk_%d", c.Table.Name, constraintCnt) + name = fmt.Sprintf("%s_constraint_%d", c.Table.Name, constraintCnt) } parsedConstraint, err := convertConstraintDefinition(ctx, unknownConstraint, name) if err != nil { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 873f1d2bf8..c06a0942b4 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -368,7 +368,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "", + Name: "t1_constraint_0", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -418,7 +418,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "", + Name: "t1_constraint_0", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -443,7 +443,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "", + Name: "t1_constraint_0", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -468,7 +468,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "", + Name: "t1_constraint_0", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -498,7 +498,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "", + Name: "t1_constraint_0", Columns: []string{"b_id", "c_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b", "c"}, diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index d6508fac1b..49ac34b392 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -1,4 +1,4 @@ -// Copyright 2020-2021 Dolthub, Inc. +// Copyright 2021 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From ca8068f3a188ee74f3fb369c84b5be01240ecdc0 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 26 Feb 2021 16:34:25 -0800 Subject: [PATCH 03/16] Create table tests --- memory/table.go | 8 +++--- sql/core.go | 6 ++-- sql/parse/parse.go | 24 +++++++++------- sql/parse/parse_test.go | 62 +++++++++++++++++++++++++++++++++++++++++ sql/plan/alter_check.go | 16 ++++------- sql/plan/ddl.go | 18 ++++++++++-- sql/storedprocedure.go | 7 +++-- 7 files changed, 110 insertions(+), 31 deletions(-) diff --git a/memory/table.go b/memory/table.go index f24fa88aab..0fdb46f819 100644 --- a/memory/table.go +++ b/memory/table.go @@ -37,7 +37,7 @@ type Table struct { columns []int indexes map[string]sql.Index foreignKeys []sql.ForeignKeyConstraint - checks []sql.CheckConstraint + checks []sql.CheckConstraint pkIndexesEnabled bool // Data storage @@ -1049,7 +1049,7 @@ func (t *Table) GetForeignKeys(_ *sql.Context) ([]sql.ForeignKeyConstraint, erro return t.foreignKeys, nil } -// GetForeignKeys implements sql.ForeignKeyTable +// GetChecks implements sql.ForeignKeyTable func (t *Table) GetChecks(_ *sql.Context) ([]sql.CheckConstraint, error) { return t.checks, nil } @@ -1094,8 +1094,8 @@ func (t *Table) CreateCheckConstraint(_ *sql.Context, chName string, expr sql.Ex } t.checks = append(t.checks, sql.CheckConstraint{ - Name: chName, - Expr: expr, + Name: chName, + Expr: expr, Enforced: enforced, }) diff --git a/sql/core.go b/sql/core.go index 92fe30e2fa..521b816398 100644 --- a/sql/core.go +++ b/sql/core.go @@ -227,9 +227,9 @@ type ForeignKeyConstraint struct { // CheckConstraint declares a constraint between the columns of two tables. type CheckConstraint struct { - Name string - Expr Expression - Enforced bool + Name string + Expr Expression + Enforced bool } // TableWrapper is a node that wraps the real table. This is needed because diff --git a/sql/parse/parse.go b/sql/parse/parse.go index d7e7c16cd3..8a3b5215ee 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -920,6 +920,7 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { } var fkDefs []*sql.ForeignKeyConstraint + var chDefs []*sql.CheckConstraint constraintCnt := 0 for _, unknownConstraint := range c.TableSpec.Constraints { name := unknownConstraint.Name @@ -933,6 +934,8 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { switch constraint := parsedConstraint.(type) { case *sql.ForeignKeyConstraint: fkDefs = append(fkDefs, constraint) + case *sql.CheckConstraint: + chDefs = append(chDefs, constraint) default: return nil, ErrUnknownConstraintDefinition.New(unknownConstraint.Name, unknownConstraint) } @@ -1002,7 +1005,7 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { } return plan.NewCreateTable( - sql.UnresolvedDatabase(""), c.Table.Name.String(), schema, c.IfNotExists, idxDefs, fkDefs), nil + sql.UnresolvedDatabase(""), c.Table.Name.String(), schema, c.IfNotExists, idxDefs, fkDefs, chDefs), nil } type namedConstraint struct { @@ -1027,15 +1030,16 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin OnUpdate: convertReferenceAction(fkConstraint.OnUpdate), OnDelete: convertReferenceAction(fkConstraint.OnDelete), }, nil - } else if cConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinition); ok { - c, err := exprToExpression(ctx, cConstraint.Expr) - if err != nil { - return nil, err - } - return &sql.CheckConstraint{ - Name: name, - Expr: c, - }, nil + } else if chConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinition); ok { + c, err := exprToExpression(ctx, chConstraint.Expr) + if err != nil { + return nil, err + } + return &sql.CheckConstraint{ + Name: name, + Expr: c, + Enforced: chConstraint.Enforced, + }, nil } else if len(cd.Name) > 0 && cd.Details == nil { return namedConstraint{cd.Name}, nil } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index c06a0942b4..8c0b947014 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -84,6 +84,7 @@ var fixtures = map[string]sql.Node{ false, nil, nil, + nil, ), `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY, b TEXT)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -102,6 +103,7 @@ var fixtures = map[string]sql.Node{ false, nil, nil, + nil, ), `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY COMMENT "hello", b TEXT COMMENT "goodbye")`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -122,6 +124,7 @@ var fixtures = map[string]sql.Node{ false, nil, nil, + nil, ), `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -140,6 +143,7 @@ var fixtures = map[string]sql.Node{ false, nil, nil, + nil, ), `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -158,6 +162,7 @@ var fixtures = map[string]sql.Node{ false, nil, nil, + nil, ), `CREATE TABLE IF NOT EXISTS t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -176,6 +181,7 @@ var fixtures = map[string]sql.Node{ true, nil, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -200,6 +206,7 @@ var fixtures = map[string]sql.Node{ Comment: "", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX idx_name (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -224,6 +231,7 @@ var fixtures = map[string]sql.Node{ Comment: "", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX idx_name (b) COMMENT 'hi')`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -248,6 +256,7 @@ var fixtures = map[string]sql.Node{ Comment: "hi", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, UNIQUE INDEX (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -272,6 +281,7 @@ var fixtures = map[string]sql.Node{ Comment: "", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, UNIQUE (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -296,6 +306,7 @@ var fixtures = map[string]sql.Node{ Comment: "", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b, a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -320,6 +331,7 @@ var fixtures = map[string]sql.Node{ Comment: "", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b), INDEX (b, a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -350,6 +362,7 @@ var fixtures = map[string]sql.Node{ Comment: "", }}, nil, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -375,6 +388,7 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, }}, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, CONSTRAINT fk_name FOREIGN KEY (b_id) REFERENCES t0(b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -400,6 +414,7 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, }}, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON UPDATE CASCADE)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -425,6 +440,7 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_Cascade, OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, }}, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON DELETE RESTRICT)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -450,6 +466,7 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, OnDelete: sql.ForeignKeyReferenceOption_Restrict, }}, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON UPDATE SET NULL ON DELETE NO ACTION)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -475,6 +492,7 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_SetNull, OnDelete: sql.ForeignKeyReferenceOption_NoAction, }}, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, c_id BIGINT, FOREIGN KEY (b_id, c_id) REFERENCES t0(b, c))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -505,6 +523,7 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, }}, + nil, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, c_id BIGINT, CONSTRAINT fk_name FOREIGN KEY (b_id, c_id) REFERENCES t0(b, c) ON UPDATE RESTRICT ON DELETE CASCADE)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), @@ -535,6 +554,49 @@ var fixtures = map[string]sql.Node{ OnUpdate: sql.ForeignKeyReferenceOption_Restrict, OnDelete: sql.ForeignKeyReferenceOption_Cascade, }}, + nil, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY, CHECK (a > 0))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + false, + nil, + nil, + []*sql.CheckConstraint{{ + Name: "t1_constraint_0", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY, CONSTRAINT ch1 CHECK (a > 0))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + false, + nil, + nil, + []*sql.CheckConstraint{{ + Name: "ch1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, ), `DROP TABLE foo;`: plan.NewDropTable( sql.UnresolvedDatabase(""), false, "foo", diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index 49ac34b392..ce8b7a1c2d 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -22,17 +22,13 @@ import ( ) var ( - // ErrNoCheckSupport is returned when the table does not support FOREIGN KEY operations. - ErrNoCheckSupport = errors.NewKind("the table does not support foreign key operations: %s") - // ErrCheckMissingColumns is returned when an ALTER TABLE ADD FOREIGN KEY statement does not provide any columns - ErrCheckMissingColumns = errors.NewKind("cannot create a foreign key without columns") - // ErrAddCheckDuplicateColumn is returned when an ALTER TABLE ADD FOREIGN KEY statement has the same column multiple times - ErrAddCheckDuplicateColumn = errors.NewKind("cannot have duplicates of columns in a foreign key: `%v`") + // ErrNoCheckConstraintSupport is returned when the table does not support CONSTRAINT CHECK operations. + ErrNoCheckConstraintSupport = errors.NewKind("the table does not support check constraint operations: %s") ) type CreateCheck struct { UnaryNode - ChDef *sql.CheckConstraint + ChDef *sql.CheckConstraint } type DropCheck struct { @@ -43,7 +39,7 @@ type DropCheck struct { func NewAlterAddCheck(table sql.Node, chDef *sql.CheckConstraint) *CreateCheck { return &CreateCheck{ UnaryNode: UnaryNode{table}, - ChDef: chDef, + ChDef: chDef, } } @@ -61,7 +57,7 @@ func getCheckAlterable(node sql.Node) (sql.CheckAlterableTable, error) { case *ResolvedTable: return getCheckAlterableTable(node.Table) default: - return nil, ErrNoCheckSupport.New(node.String()) + return nil, ErrNoCheckConstraintSupport.New(node.String()) } } @@ -72,7 +68,7 @@ func getCheckAlterableTable(t sql.Table) (sql.CheckAlterableTable, error) { case sql.TableWrapper: return getCheckAlterableTable(t.Underlying()) default: - return nil, ErrNoCheckSupport.New(t.Name()) + return nil, ErrNoCheckConstraintSupport.New(t.Name()) } } diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 1371257f9f..f933712224 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -82,6 +82,7 @@ type CreateTable struct { schema sql.Schema ifNotExists bool fkDefs []*sql.ForeignKeyConstraint + chDefs []*sql.CheckConstraint idxDefs []*IndexDefinition like sql.Node } @@ -91,7 +92,7 @@ var _ sql.Node = (*CreateTable)(nil) var _ sql.Expressioner = (*CreateTable)(nil) // NewCreateTable creates a new CreateTable node -func NewCreateTable(db sql.Database, name string, schema sql.Schema, ifNotExists bool, idxDefs []*IndexDefinition, fkDefs []*sql.ForeignKeyConstraint) *CreateTable { +func NewCreateTable(db sql.Database, name string, schema sql.Schema, ifNotExists bool, idxDefs []*IndexDefinition, fkDefs []*sql.ForeignKeyConstraint, chDefs []*sql.CheckConstraint) *CreateTable { for _, s := range schema { s.Source = name } @@ -103,6 +104,7 @@ func NewCreateTable(db sql.Database, name string, schema sql.Schema, ifNotExists ifNotExists: ifNotExists, idxDefs: idxDefs, fkDefs: fkDefs, + chDefs: chDefs, } } @@ -151,7 +153,7 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error } //TODO: in the event that foreign keys or indexes aren't supported, you'll be left with a created table and no foreign keys/indexes //this also means that if a foreign key or index fails, you'll only have what was declared up to the failure - if len(c.idxDefs) > 0 || len(c.fkDefs) > 0 { + if len(c.idxDefs) > 0 || len(c.fkDefs) > 0 || len(c.chDefs) > 0 { tableNode, ok, err := c.db.GetTableInsensitive(ctx, c.name) if err != nil { return sql.RowsToRowIter(), err @@ -183,6 +185,18 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error } } } + if len(c.chDefs) > 0 { + chAlterable, ok := tableNode.(sql.CheckAlterableTable) + if !ok { + return sql.RowsToRowIter(), ErrNoCheckConstraintSupport.New(c.name) + } + for _, chDef := range c.chDefs { + err = chAlterable.CreateCheckConstraint(ctx, chDef.Name, chDef.Expr, chDef.Enforced) + if err != nil { + return sql.RowsToRowIter(), err + } + } + } } return sql.RowsToRowIter(), nil } diff --git a/sql/storedprocedure.go b/sql/storedprocedure.go index cf7905bb57..0fd0f18968 100644 --- a/sql/storedprocedure.go +++ b/sql/storedprocedure.go @@ -22,6 +22,7 @@ import ( // ProcedureSecurityContext determines whether the stored procedure is executed using the privileges of the definer or // the invoker. type ProcedureSecurityContext byte + const ( // ProcedureSecurityContext_Definer uses the definer's security context. ProcedureSecurityContext_Definer ProcedureSecurityContext = iota @@ -31,6 +32,7 @@ const ( // ProcedureParamDirection represents the use case of the stored procedure parameter. type ProcedureParamDirection byte + const ( // ProcedureParamDirection_In means the parameter passes its contained value to the stored procedure. ProcedureParamDirection_In ProcedureParamDirection = iota @@ -45,12 +47,13 @@ const ( // ProcedureParam represents the parameter of a stored procedure. type ProcedureParam struct { Direction ProcedureParamDirection // Direction is the direction of the parameter. - Name string // Name is the name of the parameter. - Type Type // Type is the SQL type of the parameter. + Name string // Name is the name of the parameter. + Type Type // Type is the SQL type of the parameter. } // Characteristic represents a characteristic that is defined on either a stored procedure or stored function. type Characteristic byte + const ( Characteristic_LanguageSql Characteristic = iota Characteristic_Deterministic From 61d5a1b8930ff5e32f841776e6017be8d0a4ea69 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Mon, 1 Mar 2021 09:33:51 -0800 Subject: [PATCH 04/16] Remove constraint name generation, add check tests --- sql/parse/parse.go | 15 +++++---------- sql/parse/parse_test.go | 34 ++++++++++++++++++++++++++++------ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 8a3b5215ee..cf94b7cd44 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -702,7 +702,7 @@ func convertAlterTable(ctx *sql.Context, ddl *sqlparser.DDL) (sql.Node, error) { //TODO: support multiple constraints in a single ALTER statement if ddl.ConstraintAction != "" && len(ddl.TableSpec.Constraints) == 1 { table := tableNameToUnresolvedTable(ddl.Table) - parsedConstraint, err := convertConstraintDefinition(ctx, ddl.TableSpec.Constraints[0], ddl.TableSpec.Constraints[0].Name) + parsedConstraint, err := convertConstraintDefinition(ctx, ddl.TableSpec.Constraints[0]) if err != nil { return nil, err } @@ -921,13 +921,8 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { var fkDefs []*sql.ForeignKeyConstraint var chDefs []*sql.CheckConstraint - constraintCnt := 0 for _, unknownConstraint := range c.TableSpec.Constraints { - name := unknownConstraint.Name - if name == "" { - name = fmt.Sprintf("%s_constraint_%d", c.Table.Name, constraintCnt) - } - parsedConstraint, err := convertConstraintDefinition(ctx, unknownConstraint, name) + parsedConstraint, err := convertConstraintDefinition(ctx, unknownConstraint) if err != nil { return nil, err } @@ -1012,7 +1007,7 @@ type namedConstraint struct { name string } -func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefinition, name string) (interface{}, error) { +func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefinition) (interface{}, error) { if fkConstraint, ok := cd.Details.(*sqlparser.ForeignKeyDefinition); ok { columns := make([]string, len(fkConstraint.Source)) for i, col := range fkConstraint.Source { @@ -1023,7 +1018,7 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin refColumns[i] = col.String() } return &sql.ForeignKeyConstraint{ - Name: name, + Name: cd.Name, Columns: columns, ReferencedTable: fkConstraint.ReferencedTable.Name.String(), ReferencedColumns: refColumns, @@ -1036,7 +1031,7 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin return nil, err } return &sql.CheckConstraint{ - Name: name, + Name: cd.Name, Expr: c, Enforced: chConstraint.Enforced, }, nil diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 8c0b947014..667e8bb233 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -381,7 +381,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "t1_constraint_0", + Name: "", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -433,7 +433,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "t1_constraint_0", + Name: "", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -459,7 +459,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "t1_constraint_0", + Name: "", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -485,7 +485,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "t1_constraint_0", + Name: "", Columns: []string{"b_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b"}, @@ -516,7 +516,7 @@ var fixtures = map[string]sql.Node{ false, nil, []*sql.ForeignKeyConstraint{{ - Name: "t1_constraint_0", + Name: "", Columns: []string{"b_id", "c_id"}, ReferencedTable: "t0", ReferencedColumns: []string{"b", "c"}, @@ -569,7 +569,7 @@ var fixtures = map[string]sql.Node{ nil, nil, []*sql.CheckConstraint{{ - Name: "t1_constraint_0", + Name: "", Expr: expression.NewGreaterThan( expression.NewUnresolvedColumn("a"), expression.NewLiteral(int8(0), sql.Int8), @@ -795,6 +795,28 @@ var fixtures = map[string]sql.Node{ OnDelete: sql.ForeignKeyReferenceOption_Cascade, }, ), + `ALTER TABLE t1 ADD CHECK (a > 0)`: plan.NewAlterAddCheck( + plan.NewUnresolvedTable("t1", ""), + &sql.CheckConstraint{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + ), + `ALTER TABLE t1 ADD CONSTRAINT ch1 CHECK (a > 0)`: plan.NewAlterAddCheck( + plan.NewUnresolvedTable("t1", ""), + &sql.CheckConstraint{ + Name: "ch1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + ), `ALTER TABLE t1 DROP FOREIGN KEY fk_name`: plan.NewAlterDropForeignKey( plan.NewUnresolvedTable("t1", ""), &sql.ForeignKeyConstraint{ From b820244dd920835daea2c8003d574ff1dbbc7923 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Mon, 1 Mar 2021 15:09:38 -0800 Subject: [PATCH 05/16] Engine tests progress --- engine.go | 4 +- enginetest/enginetests.go | 148 ++++++++++++++++++++++++++++ enginetest/harness.go | 9 ++ enginetest/memory_engine_test.go | 12 +++ memory/table.go | 35 +++++-- sql/analyzer/resolve_create_like.go | 2 +- sql/core.go | 7 ++ sql/errors.go | 7 ++ sql/parse/parse.go | 11 ++- sql/plan/alter_check.go | 43 +++++++- sql/plan/ddl_test.go | 2 +- 11 files changed, 263 insertions(+), 17 deletions(-) diff --git a/engine.go b/engine.go index e675b7f41f..0a82c607ea 100644 --- a/engine.go +++ b/engine.go @@ -176,7 +176,7 @@ func (e *Engine) QueryWithBindings( case *plan.CreateIndex: typ = sql.CreateIndexProcess perm = auth.ReadPerm | auth.WritePerm - case *plan.CreateForeignKey, *plan.DropForeignKey, *plan.AlterIndex, *plan.CreateView, + case *plan.CreateForeignKey, *plan.CreateCheck, *plan.DropForeignKey, *plan.AlterIndex, *plan.CreateView, *plan.DeleteFrom, *plan.DropIndex, *plan.DropView, *plan.InsertInto, *plan.LockTables, *plan.UnlockTables, *plan.Update: @@ -242,7 +242,7 @@ func ResolveDefaults(tableName string, schema []*ColumnWithRawDefault) (sql.Sche return unresolvedSchema, nil } // *plan.CreateTable properly handles resolving default values, so we hijack it - createTable := plan.NewCreateTable(db, tableName, unresolvedSchema, false, nil, nil) + createTable := plan.NewCreateTable(db, tableName, unresolvedSchema, false, nil, nil, nil) analyzed, err := e.Analyzer.Analyze(ctx, createTable, nil) if err != nil { return nil, err diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 62d3aa63a9..3c9d60c126 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -114,6 +114,14 @@ func createForeignKeys(t *testing.T, harness Harness, engine *sqle.Engine) { } } +func createCheckConstraint(t *testing.T, harness Harness, engine *sqle.Engine) { + if chk, ok := harness.(CheckConstraintHarness); ok && chk.SupportsCheckConstraint() { + ctx := NewContextWithEngine(harness, engine) + TestQueryWithContext(t, ctx, engine, + "ALTER TABLE chk_tbl ADD CONSTRAINT chk1 CHECK (a > 0)", + nil, nil) + } +} // Tests generating the correct query plans for various queries using databases and tables provided by the given // harness. func TestQueryPlans(t *testing.T, harness Harness) { @@ -2155,6 +2163,146 @@ func TestDropForeignKeys(t *testing.T, harness Harness) { assert.True(t, sql.ErrTableNotFound.Is(err)) } +func TestCreateCheckConstraints(t *testing.T, harness Harness) { + require := require.New(t) + + e := NewEngine(t, harness) + + TestQuery(t, harness, e, + "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", + []sql.Row(nil), + nil, + ) + + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + ctx := NewContext(harness) + table, ok, err := db.GetTableInsensitive(ctx, "t1") + require.NoError(err) + require.True(ok) + + cht, ok := table.(sql.CheckConstraintTable) + require.True(ok) + + checks, err := cht.GetCheckConstraints(NewContext(harness)) + require.NoError(err) + + expected := []sql.CheckConstraint{{ + Name: "chk2", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("b"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }} + assert.Equal(t, expected, checks) + + // Some faulty create statements + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t2 ADD CONSTRAINT chk2 CHECK (c > 0)") + require.Error(err) + assert.True(t, sql.ErrTableNotFound.Is(err)) + + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (c > 0)") + require.Error(err) + assert.True(t, sql.ErrTableColumnNotFound.Is(err)) +} + +func TestDisallowedCheckConstraints(t *testing.T, harness Harness) { + require := require.New(t) + e := NewEngine(t, harness) + var err error + + TestQuery(t, harness, e, + "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", + []sql.Row(nil), + nil, + ) + + // functions, UDFs, procedures + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (current_user = \"root@\")") + require.Error(err) + assert.True(t, sql.ErrInvalidConstraintFunctionsNotSupported.Is(err)) + + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK ((select count(*) from t1) = 0)") + require.Error(err) + assert.True(t, sql.ErrInvalidConstraintSubqueryNotSupported.Is(err)) +} + +func TestDropCheckConstraints(t *testing.T, harness Harness) { + require := require.New(t) + + e := NewEngine(t, harness) + + TestQuery(t, harness, e, + "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER, c integer)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (a > 0)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (c > 0)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 DROP CONSTRAINT chk3", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 DROP CHECK chk1", + []sql.Row(nil), + nil, + ) + + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + ctx := NewContext(harness) + table, ok, err := db.GetTableInsensitive(ctx, "t1") + require.NoError(err) + require.True(ok) + + cht, ok := table.(sql.CheckConstraintTable) + require.True(ok) + + checks, err := cht.GetCheckConstraints(NewContext(harness)) + require.NoError(err) + + expected := []sql.CheckConstraint{{ + Name: "chk2", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("b"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }} + assert.Equal(t, expected, checks) + + // Some faulty create statements + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t2 DROP CONSTRAINT chk2") + require.Error(err) + assert.True(t, sql.ErrTableNotFound.Is(err)) + + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 DROP CHECK chk3") + require.NoError(err) +} + func TestNaturalJoin(t *testing.T, harness Harness) { require := require.New(t) diff --git a/enginetest/harness.go b/enginetest/harness.go index 0a5b104e00..700ce74c67 100755 --- a/enginetest/harness.go +++ b/enginetest/harness.go @@ -66,6 +66,15 @@ type ForeignKeyHarness interface { SupportsForeignKeys() bool } +// CheckConstraintHarness is an extension to Harness that lets an integrator test their implementation with check constraints. +// Integrator tables must implement sql.CheckAlterableTable and sql.CheckConstraintTable. +type CheckConstraintHarness interface { + Harness + // SupportsCheckConstraint returns whether this harness should accept CREATE CHECK statements as part of test + // setup. + SupportsCheckConstraint() bool +} + // VersionedDBHarness is an extension to Harness that lets an integrator test their implementation of versioned (AS OF) // queries. Integrators must implement sql.VersionedDatabase. For each table version being created, there will be a // call to NewTableAsOf, some number of Delete and Insert operations, and then a call to SnapshotTable. diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 717d561c1b..d764f87103 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -317,6 +317,18 @@ func TestDropForeignKeys(t *testing.T) { enginetest.TestDropForeignKeys(t, enginetest.NewDefaultMemoryHarness()) } +func TestCreateCheckConstraints(t *testing.T) { + enginetest.TestCreateCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) +} + +func TestTestDisallowedCheckConstraints(t *testing.T) { + enginetest.TestDisallowedCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) +} + +func TestDropCheckConstraints(t *testing.T) { + enginetest.TestDropCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) +} + func TestExplode(t *testing.T) { enginetest.TestExplode(t, enginetest.NewDefaultMemoryHarness()) } diff --git a/memory/table.go b/memory/table.go index 0fdb46f819..5b34146269 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1050,7 +1050,7 @@ func (t *Table) GetForeignKeys(_ *sql.Context) ([]sql.ForeignKeyConstraint, erro } // GetChecks implements sql.ForeignKeyTable -func (t *Table) GetChecks(_ *sql.Context) ([]sql.CheckConstraint, error) { +func (t *Table) GetCheckConstraints(_ *sql.Context) ([]sql.CheckConstraint, error) { return t.checks, nil } @@ -1062,6 +1062,12 @@ func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string } } + for _, key := range t.checks { + if key.Name == fkName { + return fmt.Errorf("constraint %s already exists", fkName) + } + } + t.foreignKeys = append(t.foreignKeys, sql.ForeignKeyConstraint{ Name: fkName, Columns: columns, @@ -1076,12 +1082,23 @@ func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string // DropForeignKey implements sql.ForeignKeyAlterableTable. func (t *Table) DropForeignKey(ctx *sql.Context, fkName string) error { + return t.DropConstraint(ctx, fkName) +} + +// DropForeignKey implements sql.ForeignKeyAlterableTable. +func (t *Table) DropConstraint(ctx *sql.Context, name string) error { for i, key := range t.foreignKeys { - if key.Name == fkName { + if key.Name == name { t.foreignKeys = append(t.foreignKeys[:i], t.foreignKeys[i+1:]...) return nil } } + for i, key := range t.checks { + if key.Name == name { + t.checks = append(t.checks[:i], t.checks[i+1:]...) + return nil + } + } return nil } @@ -1093,6 +1110,12 @@ func (t *Table) CreateCheckConstraint(_ *sql.Context, chName string, expr sql.Ex } } + for _, key := range t.foreignKeys { + if key.Name == chName { + return fmt.Errorf("constraint %s already exists", chName) + } + } + t.checks = append(t.checks, sql.CheckConstraint{ Name: chName, Expr: expr, @@ -1104,13 +1127,7 @@ func (t *Table) CreateCheckConstraint(_ *sql.Context, chName string, expr sql.Ex // func (t *Table) DropCheck(ctx *sql.Context, chName string) error {} implements sql.CheckAlterableTable. func (t *Table) DropCheckConstraint(ctx *sql.Context, chName string) error { - for i, key := range t.checks { - if key.Name == chName { - t.checks = append(t.checks[:i], t.checks[i+1:]...) - return nil - } - } - return nil + return t.DropConstraint(ctx, chName) } func (t *Table) createIndex(name string, columns []sql.IndexColumn, constraint sql.IndexConstraint, comment string) (sql.Index, error) { diff --git a/sql/analyzer/resolve_create_like.go b/sql/analyzer/resolve_create_like.go index 27418fe921..db98f6b703 100644 --- a/sql/analyzer/resolve_create_like.go +++ b/sql/analyzer/resolve_create_like.go @@ -68,5 +68,5 @@ func resolveCreateLike(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) tempCol.Source = planCreate.Name() newSch[i] = &tempCol } - return plan.NewCreateTable(planCreate.Database(), planCreate.Name(), newSch, planCreate.IfNotExists(), idxDefs, nil), nil + return plan.NewCreateTable(planCreate.Database(), planCreate.Name(), newSch, planCreate.IfNotExists(), idxDefs, nil, nil), nil } diff --git a/sql/core.go b/sql/core.go index 521b816398..851354073f 100644 --- a/sql/core.go +++ b/sql/core.go @@ -356,6 +356,13 @@ type ForeignKeyAlterableTable interface { DropForeignKey(ctx *Context, fkName string) error } +// CheckConstraintTable is a table that can declare its check constraints. +type CheckConstraintTable interface { + Table + // GetCheckConstraints returns the check constraints on this table. + GetCheckConstraints(ctx *Context) ([]CheckConstraint, error) +} + // ForeignKeyAlterableTable represents a table that supports foreign key modification operations. type CheckAlterableTable interface { Table diff --git a/sql/errors.go b/sql/errors.go index aa443c4e23..cb9da86400 100755 --- a/sql/errors.go +++ b/sql/errors.go @@ -150,6 +150,13 @@ var ( // ErrCannotDropDatabaseDoesntExist is returned when a DROP DATABASE is callend when a table is dropped that doesn't exist. ErrCannotDropDatabaseDoesntExist = errors.NewKind("can't drop database %s; database doesn't exist") + + // ErrInvalidConstraintFunctionsNotSupported is returned when a CONSTRAINT CHECK is called with a sub-function expression. + ErrInvalidConstraintFunctionsNotSupported = errors.NewKind("Invalid constraint expression, functions not supported: %s") + + + // ErrInvalidConstraintSubqueryNotSupported is returned when a CONSTRAINT CHECK is called with a sub-query expression. + ErrInvalidConstraintSubqueryNotSupported = errors.NewKind("Invalid constraint expression, sub-queries not supported: %s") ) func CastSQLError(err error) (*mysql.SQLError, bool) { diff --git a/sql/parse/parse.go b/sql/parse/parse.go index cf94b7cd44..ff9ba615ce 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1026,10 +1026,15 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin OnDelete: convertReferenceAction(fkConstraint.OnDelete), }, nil } else if chConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinition); ok { - c, err := exprToExpression(ctx, chConstraint.Expr) - if err != nil { - return nil, err + var c sql.Expression + var err error + if chConstraint.Expr != nil { + c, err = exprToExpression(ctx, chConstraint.Expr) + if err != nil { + return nil, err + } } + return &sql.CheckConstraint{ Name: cd.Name, Expr: c, diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index ce8b7a1c2d..c69fb39a79 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -16,6 +16,7 @@ package plan import ( "fmt" + "github.com/dolthub/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -78,6 +79,46 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { if err != nil { return err } + + // Make sure that all columns are valid, in the table, and there are no duplicates + cols := make(map[string]bool) + for _, col := range chAlterable.Schema() { + cols[col.Name] = true + } + + sql.Inspect(p.ChDef.Expr, func(expr sql.Expression) bool { + switch expr := expr.(type) { + case *expression.UnresolvedColumn: + if _, ok := cols[expr.Name()]; !ok { + err = sql.ErrTableColumnNotFound.New(expr.Name()) + return false + } + case *expression.UnresolvedFunction: + err = sql.ErrInvalidConstraintFunctionsNotSupported.New(expr.String()) + return false + case *Subquery: + err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) + return false + } + return true + }) + if err != nil { + return err + } + //switch p.ChDef.Expr.(type): + // case expression.BinaryExpression: + //for _, chCol := range p.ChDef.Expr. { + // if seen, ok := seenCols[fkCol]; ok { + // if !seen { + // seenCols[fkCol] = true + // } else { + // return ErrAddForeignKeyDuplicateColumn.New(fkCol) + // } + // } else { + // return sql.ErrTableColumnNotFound.New(fkCol) + // } + //} + return chAlterable.CreateCheckConstraint(ctx, p.ChDef.Name, p.ChDef.Expr, p.ChDef.Enforced) } @@ -112,7 +153,7 @@ func (p *CreateCheck) WithChildren(children ...sql.Node) (sql.Node, error) { if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return NewAlterDropCheck(children[0], p.ChDef), nil + return NewAlterAddCheck(children[0], p.ChDef), nil } func (p *CreateCheck) Schema() sql.Schema { return nil } diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index 1b6c54479c..5f6718b056 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -100,7 +100,7 @@ func TestDropTable(t *testing.T) { } func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema, ifNotExists bool) error { - c := NewCreateTable(db, name, schema, ifNotExists, nil, nil) + c := NewCreateTable(db, name, schema, ifNotExists, nil, nil, nil) rows, err := c.RowIter(sql.NewEmptyContext(), nil) if err != nil { From eb10c6dbe9a0bd732fe18c197653801eb18945ec Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Mon, 1 Mar 2021 15:51:40 -0800 Subject: [PATCH 06/16] Go fmt --- enginetest/enginetests.go | 1 + sql/errors.go | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 3c9d60c126..e2683d2950 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -122,6 +122,7 @@ func createCheckConstraint(t *testing.T, harness Harness, engine *sqle.Engine) { nil, nil) } } + // Tests generating the correct query plans for various queries using databases and tables provided by the given // harness. func TestQueryPlans(t *testing.T, harness Harness) { diff --git a/sql/errors.go b/sql/errors.go index cb9da86400..5fc168a36c 100755 --- a/sql/errors.go +++ b/sql/errors.go @@ -154,7 +154,6 @@ var ( // ErrInvalidConstraintFunctionsNotSupported is returned when a CONSTRAINT CHECK is called with a sub-function expression. ErrInvalidConstraintFunctionsNotSupported = errors.NewKind("Invalid constraint expression, functions not supported: %s") - // ErrInvalidConstraintSubqueryNotSupported is returned when a CONSTRAINT CHECK is called with a sub-query expression. ErrInvalidConstraintSubqueryNotSupported = errors.NewKind("Invalid constraint expression, sub-queries not supported: %s") ) From b2d9ea467f9b7941ba75d7b10ef7b893f7df44ba Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Tue, 2 Mar 2021 15:40:35 -0800 Subject: [PATCH 07/16] Simplify create table signature --- engine.go | 2 +- enginetest/enginetests.go | 9 - memory/table.go | 5 +- sql/analyzer/resolve_create_like.go | 8 +- sql/core.go | 2 +- sql/parse/parse.go | 9 +- sql/parse/parse_test.go | 902 ++++++++++++++-------------- sql/plan/ddl.go | 46 +- sql/plan/ddl_test.go | 2 +- 9 files changed, 518 insertions(+), 467 deletions(-) diff --git a/engine.go b/engine.go index 0a82c607ea..7eeb680ef8 100644 --- a/engine.go +++ b/engine.go @@ -242,7 +242,7 @@ func ResolveDefaults(tableName string, schema []*ColumnWithRawDefault) (sql.Sche return unresolvedSchema, nil } // *plan.CreateTable properly handles resolving default values, so we hijack it - createTable := plan.NewCreateTable(db, tableName, unresolvedSchema, false, nil, nil, nil) + createTable := plan.NewCreateTable(db, tableName, false, &plan.TableSpec{Schema: unresolvedSchema}) analyzed, err := e.Analyzer.Analyze(ctx, createTable, nil) if err != nil { return nil, err diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index e2683d2950..2638ed46e8 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -114,15 +114,6 @@ func createForeignKeys(t *testing.T, harness Harness, engine *sqle.Engine) { } } -func createCheckConstraint(t *testing.T, harness Harness, engine *sqle.Engine) { - if chk, ok := harness.(CheckConstraintHarness); ok && chk.SupportsCheckConstraint() { - ctx := NewContextWithEngine(harness, engine) - TestQueryWithContext(t, ctx, engine, - "ALTER TABLE chk_tbl ADD CONSTRAINT chk1 CHECK (a > 0)", - nil, nil) - } -} - // Tests generating the correct query plans for various queries using databases and tables provided by the given // harness. func TestQueryPlans(t *testing.T, harness Harness) { diff --git a/memory/table.go b/memory/table.go index 5b34146269..19caf429b2 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1044,12 +1044,12 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { return append(indexes, nonPrimaryIndexes...), nil } -// GetForeignKeys implements sql.ForeignKeyTable +// GetForeignKeys implements sql.CheckConstraintTable func (t *Table) GetForeignKeys(_ *sql.Context) ([]sql.ForeignKeyConstraint, error) { return t.foreignKeys, nil } -// GetChecks implements sql.ForeignKeyTable +// GetChecks implements sql.CheckConstraintTable func (t *Table) GetCheckConstraints(_ *sql.Context) ([]sql.CheckConstraint, error) { return t.checks, nil } @@ -1085,7 +1085,6 @@ func (t *Table) DropForeignKey(ctx *sql.Context, fkName string) error { return t.DropConstraint(ctx, fkName) } -// DropForeignKey implements sql.ForeignKeyAlterableTable. func (t *Table) DropConstraint(ctx *sql.Context, name string) error { for i, key := range t.foreignKeys { if key.Name == name { diff --git a/sql/analyzer/resolve_create_like.go b/sql/analyzer/resolve_create_like.go index db98f6b703..3fe186a3de 100644 --- a/sql/analyzer/resolve_create_like.go +++ b/sql/analyzer/resolve_create_like.go @@ -68,5 +68,11 @@ func resolveCreateLike(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) tempCol.Source = planCreate.Name() newSch[i] = &tempCol } - return plan.NewCreateTable(planCreate.Database(), planCreate.Name(), newSch, planCreate.IfNotExists(), idxDefs, nil, nil), nil + + tableSpec := &plan.TableSpec{ + Schema: newSch, + IdxDefs: idxDefs, + } + + return plan.NewCreateTable(planCreate.Database(), planCreate.Name(), planCreate.IfNotExists(), tableSpec), nil } diff --git a/sql/core.go b/sql/core.go index 851354073f..dde558c735 100644 --- a/sql/core.go +++ b/sql/core.go @@ -225,7 +225,7 @@ type ForeignKeyConstraint struct { OnDelete ForeignKeyReferenceOption } -// CheckConstraint declares a constraint between the columns of two tables. +// CheckConstraint declares a boolean-eval constraint. type CheckConstraint struct { Name string Expr Expression diff --git a/sql/parse/parse.go b/sql/parse/parse.go index ff9ba615ce..6d4f1d1453 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -999,8 +999,15 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { } } + tableSpec := &plan.TableSpec{ + Schema: schema, + IdxDefs: idxDefs, + FkDefs: fkDefs, + ChDefs: chDefs, + } + return plan.NewCreateTable( - sql.UnresolvedDatabase(""), c.Table.Name.String(), schema, c.IfNotExists, idxDefs, fkDefs, chDefs), nil + sql.UnresolvedDatabase(""), c.Table.Name.String(), c.IfNotExists, tableSpec), nil } type namedConstraint struct { diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 667e8bb233..96c1a29871 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -48,555 +48,571 @@ var fixtures = map[string]sql.Node{ `CREATE TABLE t1(a INTEGER, b TEXT, c DATE, d TIMESTAMP, e VARCHAR(20), f BLOB NOT NULL, g DATETIME, h CHAR(40))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - }, { - Name: "c", - Type: sql.Date, - Nullable: true, - }, { - Name: "d", - Type: sql.Timestamp, - Nullable: true, - }, { - Name: "e", - Type: sql.MustCreateStringWithDefaults(sqltypes.VarChar, 20), - Nullable: true, - }, { - Name: "f", - Type: sql.Blob, - Nullable: false, - }, { - Name: "g", - Type: sql.Datetime, - Nullable: true, - }, { - Name: "h", - Type: sql.MustCreateStringWithDefaults(sqltypes.Char, 40), - Nullable: true, - }}, false, - nil, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + }, { + Name: "c", + Type: sql.Date, + Nullable: true, + }, { + Name: "d", + Type: sql.Timestamp, + Nullable: true, + }, { + Name: "e", + Type: sql.MustCreateStringWithDefaults(sqltypes.VarChar, 20), + Nullable: true, + }, { + Name: "f", + Type: sql.Blob, + Nullable: false, + }, { + Name: "g", + Type: sql.Datetime, + Nullable: true, + }, { + Name: "h", + Type: sql.MustCreateStringWithDefaults(sqltypes.Char, 40), + Nullable: true, + }}, + }, ), `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY, b TEXT)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + }, ), `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY COMMENT "hello", b TEXT COMMENT "goodbye")`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - Comment: "hello", - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - PrimaryKey: false, - Comment: "goodbye", - }}, false, - nil, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + Comment: "hello", + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + Comment: "goodbye", + }}, + }, ), `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + }, ), `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: false, - PrimaryKey: true, - }}, false, - nil, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: false, + PrimaryKey: true, + }}, + }, ), `CREATE TABLE IF NOT EXISTS t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: false, - PrimaryKey: true, - }}, true, - nil, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: false, + PrimaryKey: true, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX idx_name (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "idx_name", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "idx_name", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX idx_name (b) COMMENT 'hi')`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "idx_name", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "hi", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "idx_name", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "hi", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, UNIQUE INDEX (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_Unique, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_Unique, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, UNIQUE (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_Unique, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_Unique, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b, a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, - Comment: "", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b), INDEX (b, a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }, { - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, - Comment: "", - }}, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }, { + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, CONSTRAINT fk_name FOREIGN KEY (b_id) REFERENCES t0(b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "fk_name", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "fk_name", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON UPDATE CASCADE)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_Cascade, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_Cascade, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON DELETE RESTRICT)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_Restrict, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_Restrict, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON UPDATE SET NULL ON DELETE NO ACTION)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_SetNull, - OnDelete: sql.ForeignKeyReferenceOption_NoAction, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_SetNull, + OnDelete: sql.ForeignKeyReferenceOption_NoAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, c_id BIGINT, FOREIGN KEY (b_id, c_id) REFERENCES t0(b, c))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }, { - Name: "c_id", - Type: sql.Int64, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id", "c_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b", "c"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }, { + Name: "c_id", + Type: sql.Int64, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id", "c_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b", "c"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, c_id BIGINT, CONSTRAINT fk_name FOREIGN KEY (b_id, c_id) REFERENCES t0(b, c) ON UPDATE RESTRICT ON DELETE CASCADE)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }, { - Name: "c_id", - Type: sql.Int64, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "fk_name", - Columns: []string{"b_id", "c_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b", "c"}, - OnUpdate: sql.ForeignKeyReferenceOption_Restrict, - OnDelete: sql.ForeignKeyReferenceOption_Cascade, - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }, { + Name: "c_id", + Type: sql.Int64, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "fk_name", + Columns: []string{"b_id", "c_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b", "c"}, + OnUpdate: sql.ForeignKeyReferenceOption_Restrict, + OnDelete: sql.ForeignKeyReferenceOption_Cascade, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, CHECK (a > 0))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }}, false, - nil, - nil, - []*sql.CheckConstraint{{ - Name: "", - Expr: expression.NewGreaterThan( - expression.NewUnresolvedColumn("a"), - expression.NewLiteral(int8(0), sql.Int8), - ), - Enforced: true, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY CHECK (a > 0))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, CONSTRAINT ch1 CHECK (a > 0))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }}, false, - nil, - nil, - []*sql.CheckConstraint{{ - Name: "ch1", - Expr: expression.NewGreaterThan( - expression.NewUnresolvedColumn("a"), - expression.NewLiteral(int8(0), sql.Int8), - ), - Enforced: true, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "ch1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, ), `DROP TABLE foo;`: plan.NewDropTable( sql.UnresolvedDatabase(""), false, "foo", diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index f933712224..2333876dd0 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -75,15 +75,47 @@ type IndexDefinition struct { Comment string } +// TableSpec is a node describing the schema of a table. +type TableSpec struct { + Schema sql.Schema + FkDefs []*sql.ForeignKeyConstraint + ChDefs []*sql.CheckConstraint + IdxDefs []*IndexDefinition +} + +func (c *TableSpec) WithSchema(schema sql.Schema) (*TableSpec, error) { + nc := *c + nc.Schema = schema + return &nc, nil +} + +func (c *TableSpec) WithForeignKeys(fkDefs []*sql.ForeignKeyConstraint) (*TableSpec, error) { + nc := *c + nc.FkDefs = fkDefs + return &nc, nil +} + +func (c *TableSpec) WithCheckConstraints(chDefs []*sql.CheckConstraint) (*TableSpec, error) { + nc := *c + nc.ChDefs = chDefs + return &nc, nil +} + +func (c *TableSpec) WithIndices(idxDefs []*IndexDefinition) (*TableSpec, error) { + nc := *c + nc.IdxDefs = idxDefs + return &nc, nil +} + // CreateTable is a node describing the creation of some table. type CreateTable struct { ddlNode name string schema sql.Schema - ifNotExists bool fkDefs []*sql.ForeignKeyConstraint chDefs []*sql.CheckConstraint idxDefs []*IndexDefinition + ifNotExists bool like sql.Node } @@ -92,19 +124,19 @@ var _ sql.Node = (*CreateTable)(nil) var _ sql.Expressioner = (*CreateTable)(nil) // NewCreateTable creates a new CreateTable node -func NewCreateTable(db sql.Database, name string, schema sql.Schema, ifNotExists bool, idxDefs []*IndexDefinition, fkDefs []*sql.ForeignKeyConstraint, chDefs []*sql.CheckConstraint) *CreateTable { - for _, s := range schema { +func NewCreateTable(db sql.Database, name string, ifNotExists bool, tableSpec *TableSpec) *CreateTable { + for _, s := range tableSpec.Schema { s.Source = name } return &CreateTable{ ddlNode: ddlNode{db}, name: name, - schema: schema, + schema: tableSpec.Schema, + fkDefs: tableSpec.FkDefs, + chDefs: tableSpec.ChDefs, + idxDefs: tableSpec.IdxDefs, ifNotExists: ifNotExists, - idxDefs: idxDefs, - fkDefs: fkDefs, - chDefs: chDefs, } } diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index 5f6718b056..d32fca75e4 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -100,7 +100,7 @@ func TestDropTable(t *testing.T) { } func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema, ifNotExists bool) error { - c := NewCreateTable(db, name, schema, ifNotExists, nil, nil, nil) + c := NewCreateTable(db, name, ifNotExists, &TableSpec{Schema: schema}) rows, err := c.RowIter(sql.NewEmptyContext(), nil) if err != nil { From 827a5509174d5e44b5b5229f05a02fb45226d707 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 5 Mar 2021 17:55:00 -0800 Subject: [PATCH 08/16] Check interface string-friendly, expressions executing on alter table --- enginetest/enginetests.go | 18 ++-- memory/table.go | 40 ++++----- sql/core.go | 23 +++-- sql/parse/parse.go | 31 +++++++ sql/plan/alter_check.go | 181 +++++++++++++++++++++++++++----------- sql/plan/ddl.go | 4 +- 6 files changed, 209 insertions(+), 88 deletions(-) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 2638ed46e8..12781c817d 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2159,6 +2159,8 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { require := require.New(t) e := NewEngine(t, harness) + //e.Analyzer.Debug = true + //e.Analyzer.Verbose = true TestQuery(t, harness, e, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", @@ -2179,20 +2181,22 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { require.NoError(err) require.True(ok) - cht, ok := table.(sql.CheckConstraintTable) + cht, ok := table.(sql.CheckTable) require.True(ok) - checks, err := cht.GetCheckConstraints(NewContext(harness)) + checks, err := cht.GetChecks(NewContext(harness)) require.NoError(err) - expected := []sql.CheckConstraint{{ + cmp := sql.CheckConstraint{ Name: "chk2", Expr: expression.NewGreaterThan( - expression.NewUnresolvedColumn("b"), + expression.NewUnresolvedColumn("t1.b"), expression.NewLiteral(int8(0), sql.Int8), ), Enforced: true, - }} + } + expected := []sql.CheckDefinition{*plan.NewCheckDefinition(&cmp)} + assert.Equal(t, expected, checks) // Some faulty create statements @@ -2270,10 +2274,10 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { require.NoError(err) require.True(ok) - cht, ok := table.(sql.CheckConstraintTable) + cht, ok := table.(sql.CheckTable) require.True(ok) - checks, err := cht.GetCheckConstraints(NewContext(harness)) + checks, err := cht.GetChecks(NewContext(harness)) require.NoError(err) expected := []sql.CheckConstraint{{ diff --git a/memory/table.go b/memory/table.go index 19caf429b2..168505e5ba 100644 --- a/memory/table.go +++ b/memory/table.go @@ -37,7 +37,7 @@ type Table struct { columns []int indexes map[string]sql.Index foreignKeys []sql.ForeignKeyConstraint - checks []sql.CheckConstraint + checks []sql.CheckDefinition pkIndexesEnabled bool // Data storage @@ -67,6 +67,8 @@ var _ sql.IndexAlterableTable = (*Table)(nil) var _ sql.IndexedTable = (*Table)(nil) var _ sql.ForeignKeyAlterableTable = (*Table)(nil) var _ sql.ForeignKeyTable = (*Table)(nil) +var _ sql.CheckAlterableTable = (*Table)(nil) +var _ sql.CheckTable = (*Table)(nil) var _ sql.AutoIncrementTable = (*Table)(nil) // PushdownTable is an extension to Table that implements sql.FilteredTable and sql.ProjectedTable. This is mostly just @@ -1044,16 +1046,11 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { return append(indexes, nonPrimaryIndexes...), nil } -// GetForeignKeys implements sql.CheckConstraintTable +// GetForeignKeys implements sql.CheckTable func (t *Table) GetForeignKeys(_ *sql.Context) ([]sql.ForeignKeyConstraint, error) { return t.foreignKeys, nil } -// GetChecks implements sql.CheckConstraintTable -func (t *Table) GetCheckConstraints(_ *sql.Context) ([]sql.CheckConstraint, error) { - return t.checks, nil -} - // CreateForeignKey implements sql.ForeignKeyAlterableTable. Foreign keys are not enforced on update / delete. func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string, referencedTable string, referencedColumns []string, onUpdate, onDelete sql.ForeignKeyReferenceOption) error { for _, key := range t.foreignKeys { @@ -1082,10 +1079,10 @@ func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string // DropForeignKey implements sql.ForeignKeyAlterableTable. func (t *Table) DropForeignKey(ctx *sql.Context, fkName string) error { - return t.DropConstraint(ctx, fkName) + return t.dropConstraint(ctx, fkName) } -func (t *Table) DropConstraint(ctx *sql.Context, name string) error { +func (t *Table) dropConstraint(ctx *sql.Context, name string) error { for i, key := range t.foreignKeys { if key.Name == name { t.foreignKeys = append(t.foreignKeys[:i], t.foreignKeys[i+1:]...) @@ -1101,32 +1098,33 @@ func (t *Table) DropConstraint(ctx *sql.Context, name string) error { return nil } +// GetChecks implements sql.CheckTable +func (t *Table) GetChecks(_ *sql.Context) ([]sql.CheckDefinition, error) { + return t.checks, nil +} + // CreateCheck implements sql.CheckAlterableTable -func (t *Table) CreateCheckConstraint(_ *sql.Context, chName string, expr sql.Expression, enforced bool) error { +func (t *Table) CreateCheck(_ *sql.Context, check *sql.CheckDefinition) error { for _, key := range t.checks { - if key.Name == chName { - return fmt.Errorf("constraint %s already exists", chName) + if key.Name == check.Name { + return fmt.Errorf("constraint %s already exists", check.Name) } } for _, key := range t.foreignKeys { - if key.Name == chName { - return fmt.Errorf("constraint %s already exists", chName) + if key.Name == check.Name { + return fmt.Errorf("constraint %s already exists", check.Name) } } - t.checks = append(t.checks, sql.CheckConstraint{ - Name: chName, - Expr: expr, - Enforced: enforced, - }) + t.checks = append(t.checks, *check) return nil } // func (t *Table) DropCheck(ctx *sql.Context, chName string) error {} implements sql.CheckAlterableTable. -func (t *Table) DropCheckConstraint(ctx *sql.Context, chName string) error { - return t.DropConstraint(ctx, chName) +func (t *Table) DropCheck(ctx *sql.Context, chName string) error { + return t.dropConstraint(ctx, chName) } func (t *Table) createIndex(name string, columns []sql.IndexColumn, constraint sql.IndexConstraint, comment string) (sql.Index, error) { diff --git a/sql/core.go b/sql/core.go index dde558c735..b4686c28de 100644 --- a/sql/core.go +++ b/sql/core.go @@ -225,6 +225,13 @@ type ForeignKeyConstraint struct { OnDelete ForeignKeyReferenceOption } +// CheckDefinition defines a trigger. Integrators are not expected to parse or understand the trigger definitions, +// but must store and return them when asked. +type CheckDefinition struct { + Name string // The name of this check. Check names in a database are unique. + AlterStatement string // Reference for expression body +} + // CheckConstraint declares a boolean-eval constraint. type CheckConstraint struct { Name string @@ -356,21 +363,21 @@ type ForeignKeyAlterableTable interface { DropForeignKey(ctx *Context, fkName string) error } -// CheckConstraintTable is a table that can declare its check constraints. -type CheckConstraintTable interface { +// CheckTable is a table that can declare its check constraints. +type CheckTable interface { Table - // GetCheckConstraints returns the check constraints on this table. - GetCheckConstraints(ctx *Context) ([]CheckConstraint, error) + // GetChecks returns the check constraints on this table. + GetChecks(ctx *Context) ([]CheckDefinition, error) } // ForeignKeyAlterableTable represents a table that supports foreign key modification operations. type CheckAlterableTable interface { Table - // CreateCheckConstraint creates an check constraint for this table, using the provided parameters. + // CreateCheck creates an check constraint for this table, using the provided parameters. // Returns an error if the constraint name already exists. - CreateCheckConstraint(ctx *Context, chName string, expr Expression, enforced bool) error - // DropCheckConstraint removes a check constraint from the database. - DropCheckConstraint(ctx *Context, chName string) error + CreateCheck(ctx *Context, check *CheckDefinition) error + // DropCheck removes a check constraint from the database. + DropCheck(ctx *Context, chName string) error } // InsertableTable is a table that can process insertion of new rows. diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 6d4f1d1453..1057f96a6e 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -59,6 +59,8 @@ var ( ErrInvalidAutoIncCols = errors.NewKind("there can be only one auto_increment column and it must be defined as a key") ErrUnknownConstraintDefinition = errors.NewKind("unknown constraint definition: %s, %T") + + ErrInvalidCheckConstraint = errors.NewKind("invalid constraint definition: %s") ) var ( @@ -1122,6 +1124,35 @@ func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { ), nil } +func convertCheckDefToConstraint(ctx *sql.Context, check *sql.CheckDefinition) (*sql.CheckConstraint, error) { + parsed, err := sqlparser.ParseStrictDDL(check.AlterStatement) + if err != nil { + return nil, err + } + + ddl, ok := parsed.(*sqlparser.DDL) + if !ok || ddl.ConstraintAction == "" || len(ddl.TableSpec.Constraints) != 1 || strings.ToLower(ddl.ConstraintAction) != sqlparser.AddStr { + return nil, ErrInvalidCheckConstraint.New(check.AlterStatement) + } + + parsedConstraint := ddl.TableSpec.Constraints[0] + chConstraint, ok := parsedConstraint.Details.(*sqlparser.CheckConstraintDefinition) + if !ok || chConstraint.Expr == nil { + return nil, ErrInvalidCheckConstraint.New(check.AlterStatement) + } + + c, err := exprToExpression(ctx, chConstraint.Expr) + if err != nil { + return nil, err + } + + return &sql.CheckConstraint{ + Name: parsedConstraint.Name, + Expr: c, + Enforced: chConstraint.Enforced, + }, nil +} + func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { node, err := tableExprsToTable(ctx, d.TableExprs) if err != nil { diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index c69fb39a79..8405b04814 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -16,15 +16,18 @@ package plan import ( "fmt" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" - - "github.com/dolthub/go-mysql-server/sql" + "io" ) var ( // ErrNoCheckConstraintSupport is returned when the table does not support CONSTRAINT CHECK operations. ErrNoCheckConstraintSupport = errors.NewKind("the table does not support check constraint operations: %s") + + // ErrNoCheckFailed is returned when the check constraint evaluates to false + ErrNoCheckFailed = errors.NewKind("check failed: %s, %s") ) type CreateCheck struct { @@ -73,6 +76,44 @@ func getCheckAlterableTable(t sql.Table) (sql.CheckAlterableTable, error) { } } +// Expressions implements the sql.Expressioner interface. +func (c *CreateCheck) Expressions() []sql.Expression { + return []sql.Expression{c.ChDef.Expr} +} + +// Resolved implements the Resolvable interface. +func (c *CreateCheck) Resolved() bool { + ok := true + sql.Inspect(c.ChDef.Expr, func(expr sql.Expression) bool { + switch expr.(type) { + case *expression.UnresolvedColumn: + ok = false + return false + } + return true + }) + return ok +} + +// WithExpressions implements the sql.Expressioner interface. +func (c *CreateCheck) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, fmt.Errorf("expected one expression, got: %d", len(exprs)) + } + + nc := *c + nc.ChDef.Expr = exprs[0] + return &nc, nil + //return &CreateCheck{ + // UnaryNode: c.UnaryNode, + // ChDef: &sql.CheckConstraint{ + // Name: c.ChDef.Name, + // Expr: exprs[0], + // Enforced: c.ChDef.Enforced, + // }, + //}, nil +} + // Execute inserts the rows in the database. func (p *CreateCheck) Execute(ctx *sql.Context) error { chAlterable, err := getCheckAlterable(p.UnaryNode.Child) @@ -80,46 +121,99 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { return err } - // Make sure that all columns are valid, in the table, and there are no duplicates - cols := make(map[string]bool) - for _, col := range chAlterable.Schema() { - cols[col.Name] = true + //check, err := ConvertCheckDefToConstraint(ctx, p.ChDef) + if err != nil { + return err } - - sql.Inspect(p.ChDef.Expr, func(expr sql.Expression) bool { - switch expr := expr.(type) { - case *expression.UnresolvedColumn: - if _, ok := cols[expr.Name()]; !ok { - err = sql.ErrTableColumnNotFound.New(expr.Name()) - return false - } - case *expression.UnresolvedFunction: - err = sql.ErrInvalidConstraintFunctionsNotSupported.New(expr.String()) - return false - case *Subquery: - err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) - return false - } - return true - }) + // check existing rows in table + var res interface{} + rowIter, err := p.UnaryNode.Child.RowIter(ctx, nil) if err != nil { return err } + for { + row, err := rowIter.Next() + if row == nil || err != io.EOF { + break + } + res, err = p.ChDef.Expr.Eval(ctx, row) + if err != nil { + return err + } + if val, ok := res.(bool); !ok || !val { + return ErrNoCheckFailed.New(p.ChDef.Expr.String(), row) + } + } + + // Make sure that all columns are valid, in the table, and there are no duplicates + //cols := make(map[string]bool) + //for _, col := range chAlterable.Schema() { + // cols[col.Name] = true + //} + // + //sql.Inspect(p.ChDef.Expr, func(expr sql.Expression) bool { + // switch expr := expr.(type) { + // case *expression.UnresolvedColumn: + // if _, ok := cols[expr.Name()]; !ok { + // err = sql.ErrTableColumnNotFound.New(expr.Name()) + // return false + // } + // case *expression.UnresolvedFunction: + // err = sql.ErrInvalidConstraintFunctionsNotSupported.New(expr.String()) + // return false + // case *Subquery: + // err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) + // return false + // } + // return true + //}) + //if err != nil { + // return err + //} //switch p.ChDef.Expr.(type): + // case expression.BinaryExpression: //for _, chCol := range p.ChDef.Expr. { // if seen, ok := seenCols[fkCol]; ok { // if !seen { // seenCols[fkCol] = true // } else { - // return ErrAddForeignKeyDuplicateColumn.New(fkCol) + // return ErrAddForeignKe yDuplicateColumn.New(fkCol) // } // } else { // return sql.ErrTableColumnNotFound.New(fkCol) // } //} - return chAlterable.CreateCheckConstraint(ctx, p.ChDef.Name, p.ChDef.Expr, p.ChDef.Enforced) + return chAlterable.CreateCheck(ctx, NewCheckDefinition(p.ChDef)) +} + +// WithChildren implements the Node interface. +func (p *CreateCheck) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewAlterAddCheck(children[0], p.ChDef), nil +} + +func (p *CreateCheck) Schema() sql.Schema { return nil } + +func (p *CreateCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + err := p.Execute(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(), nil +} + +func (p CreateCheck) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("AddCheck(%s)", p.ChDef.Name) + _ = pr.WriteChildren( + fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String()), + fmt.Sprintf("Expr(%s)", p.ChDef.Expr.String()), + ) + return pr.String() } // Execute inserts the rows in the database. @@ -128,7 +222,7 @@ func (p *DropCheck) Execute(ctx *sql.Context) error { if err != nil { return err } - return chAlterable.DropCheckConstraint(ctx, p.ChDef.Name) + return chAlterable.DropCheck(ctx, p.ChDef.Name) } // RowIter implements the Node interface. @@ -147,25 +241,7 @@ func (p *DropCheck) WithChildren(children ...sql.Node) (sql.Node, error) { } return NewAlterDropCheck(children[0], p.ChDef), nil } - -// WithChildren implements the Node interface. -func (p *CreateCheck) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) - } - return NewAlterAddCheck(children[0], p.ChDef), nil -} - -func (p *CreateCheck) Schema() sql.Schema { return nil } -func (p *DropCheck) Schema() sql.Schema { return nil } - -func (p *CreateCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - err := p.Execute(ctx) - if err != nil { - return nil, err - } - return sql.RowsToRowIter(), nil -} +func (p *DropCheck) Schema() sql.Schema { return nil } func (p DropCheck) String() string { pr := sql.NewTreePrinter() @@ -174,9 +250,14 @@ func (p DropCheck) String() string { return pr.String() } -func (p CreateCheck) String() string { - pr := sql.NewTreePrinter() - _ = pr.WriteNode("AddCheck(%s)", p.ChDef.Name) - _ = pr.WriteChildren(fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String())) - return pr.String() +func NewCheckDefinition(check *sql.CheckConstraint) *sql.CheckDefinition { + return &sql.CheckDefinition{ + Name: check.Name, + AlterStatement: fmt.Sprintf( + "ALTER TABLE _ ADD CONSTRAINT %s CHECK %s ENFORCED %v", + check.Name, + check.Expr.String(), + check.Enforced, + ), + } } diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 2333876dd0..1c938d149f 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -222,8 +222,8 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error if !ok { return sql.RowsToRowIter(), ErrNoCheckConstraintSupport.New(c.name) } - for _, chDef := range c.chDefs { - err = chAlterable.CreateCheckConstraint(ctx, chDef.Name, chDef.Expr, chDef.Enforced) + for _, ch := range c.chDefs { + err = chAlterable.CreateCheck(ctx, NewCheckDefinition(ch)) if err != nil { return sql.RowsToRowIter(), err } From c750d09adfa812bb500172d870a5398d002cbd8b Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Mon, 8 Mar 2021 22:21:35 -0800 Subject: [PATCH 09/16] Add two analyzer nodes -- one for create check, one for insert --- enginetest/enginetests.go | 9 +- sql/analyzer/check_constraints.go | 151 ++++++++++++++++++++++++++++++ sql/analyzer/rules.go | 3 + sql/parse/parse.go | 131 ++++++++++---------------- sql/parse/util.go | 2 +- sql/plan/alter_check.go | 2 +- sql/plan/insert.go | 5 +- 7 files changed, 216 insertions(+), 87 deletions(-) create mode 100644 sql/analyzer/check_constraints.go diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 12781c817d..df6499a233 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2204,7 +2204,7 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { require.Error(err) assert.True(t, sql.ErrTableNotFound.Is(err)) - _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (c > 0)") + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (c > 0)") require.Error(err) assert.True(t, sql.ErrTableColumnNotFound.Is(err)) } @@ -2280,14 +2280,15 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { checks, err := cht.GetChecks(NewContext(harness)) require.NoError(err) - expected := []sql.CheckConstraint{{ + cmp := sql.CheckConstraint{ Name: "chk2", Expr: expression.NewGreaterThan( - expression.NewUnresolvedColumn("b"), + expression.NewUnresolvedColumn("t1.b"), expression.NewLiteral(int8(0), sql.Int8), ), Enforced: true, - }} + } + expected := []sql.CheckDefinition{*plan.NewCheckDefinition(&cmp)} assert.Equal(t, expected, checks) // Some faulty create statements diff --git a/sql/analyzer/check_constraints.go b/sql/analyzer/check_constraints.go new file mode 100644 index 0000000000..9d46c9130e --- /dev/null +++ b/sql/analyzer/check_constraints.go @@ -0,0 +1,151 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzer + +import ( + "strings" + + "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/parse" + "github.com/dolthub/go-mysql-server/sql/plan" +) + +// validateCreateTrigger handles CreateTrigger nodes, resolving references to "old" and "new" table references in +// the trigger body. Also validates that these old and new references are being used appropriately -- they are only +// valid for certain kinds of triggers and certain statements. +func validateCreateCheck(ctx *sql.Context, a *Analyzer, node sql.Node, scope *Scope) (sql.Node, error) { + ct, ok := node.(*plan.CreateCheck) + if !ok { + return node, nil + } + + chAlterable, ok := ct.UnaryNode.Child.(sql.Table) + if !ok { + return node, nil + } + checkCols := make(map[string]bool) + for _, col := range chAlterable.Schema() { + checkCols[col.Name] = true + } + + var err error + plan.InspectExpressionsWithNode(node, func(n sql.Node, e sql.Expression) bool { + if _, ok := n.(*plan.CreateCheck); !ok { + return true + } + + // Make sure that all columns are valid, in the table, and there are no duplicates + switch expr := e.(type) { + case *expression.UnresolvedColumn: + if _, ok := checkCols[expr.Name()]; !ok { + err = sql.ErrTableColumnNotFound.New(expr.Name()) + return false + } + case *expression.UnresolvedFunction: + err = sql.ErrInvalidConstraintFunctionsNotSupported.New(expr.String()) + return false + case *plan.Subquery: + err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) + return false + } + return true + }) + + if err != nil { + return nil, err + } + + return ct, nil +} + +// loadChecks loads any triggers that are required for a plan node to operate properly (except for nodes dealing with +// trigger execution). +func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) { + span, _ := ctx.Span("loadChecks") + defer span.Finish() + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch node := n.(type) { + case *plan.InsertInto: + nc := *node + table, ok := nc.Destination.(sql.CheckTable) + if !ok { + return node, nil + } + loadedChecks, err := loadChecksFromTable(ctx, table) + if err != nil { + return nil, err + } + if len(loadedChecks) != 0 { + nc.Checks = loadedChecks + } else { + nc.Checks = make([]*sql.CheckConstraint, 0) + } + return &nc, nil + default: + return node, nil + } + }) +} + +func loadChecksFromTable(ctx *sql.Context, table sql.Table) ([]*sql.CheckConstraint, error) { + var loadedChecks []*sql.CheckConstraint + if checkTable, ok := table.(sql.CheckTable); ok { + checks, err := checkTable.GetChecks(ctx) + if err != nil { + return nil, err + } + for _, ch := range checks { + constraint, err := convertCheckDefToConstraint(ctx, &ch) + if err != nil { + return nil, err + } + loadedChecks = append(loadedChecks, constraint) + } + } + return loadedChecks, nil +} + +func convertCheckDefToConstraint(ctx *sql.Context, check *sql.CheckDefinition) (*sql.CheckConstraint, error) { + parsed, err := sqlparser.ParseStrictDDL(check.AlterStatement) + if err != nil { + return nil, err + } + + ddl, ok := parsed.(*sqlparser.DDL) + if !ok || ddl.ConstraintAction == "" || len(ddl.TableSpec.Constraints) != 1 || strings.ToLower(ddl.ConstraintAction) != sqlparser.AddStr { + return nil, parse.ErrInvalidCheckConstraint.New(check.AlterStatement) + } + + parsedConstraint := ddl.TableSpec.Constraints[0] + chConstraint, ok := parsedConstraint.Details.(*sqlparser.CheckConstraintDefinition) + if !ok || chConstraint.Expr == nil { + return nil, parse.ErrInvalidCheckConstraint.New(check.AlterStatement) + } + + c, err := parse.ExprToExpression(ctx, chConstraint.Expr) + if err != nil { + return nil, err + } + + return &sql.CheckConstraint{ + Name: parsedConstraint.Name, + Expr: c, + Enforced: chConstraint.Enforced, + }, nil +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index bd88687812..7a04c5aa95 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -22,6 +22,7 @@ import ( // DefaultRules. var OnceBeforeDefault = []Rule{ {"load_stored_procedures", loadStoredProcedures}, + {"load_check_constraints", loadChecks}, {"resolve_views", resolveViews}, {"resolve_tables", resolveTables}, {"resolve_set_variables", resolveSetVariables}, @@ -32,6 +33,7 @@ var OnceBeforeDefault = []Rule{ {"check_unique_table_names", checkUniqueTableNames}, {"validate_create_trigger", validateCreateTrigger}, {"validate_stored_procedure", validateStoredProcedure}, + {"validate_check_constraint", validateCreateCheck}, } // DefaultRules to apply when analyzing nodes. @@ -60,6 +62,7 @@ var DefaultRules = []Rule{ // DefaultRules. var OnceAfterDefault = []Rule{ {"load_triggers", loadTriggers}, + {"load_checks", loadChecks}, {"process_truncate", processTruncate}, {"resolve_column_defaults", resolveColumnDefaults}, {"resolve_generators", resolveGenerators}, diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 1057f96a6e..b3dd50edb9 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -322,7 +322,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.ShowTablesOpt.Filter != nil { if s.ShowTablesOpt.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + filter, err = ExprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -356,7 +356,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.ShowTablesOpt.Filter != nil { if s.ShowTablesOpt.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + filter, err = ExprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -370,7 +370,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.ShowTablesOpt.AsOf != nil { var err error - asOf, err = exprToExpression(ctx, s.ShowTablesOpt.AsOf) + asOf, err = ExprToExpression(ctx, s.ShowTablesOpt.AsOf) if err != nil { return nil, err } @@ -406,7 +406,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e } if s.ShowTablesOpt.Filter.Filter != nil { - filter, err := exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + filter, err := ExprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -429,7 +429,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e } if s.ShowCollationFilterOpt != nil { - filterExpr, err := exprToExpression(ctx, *s.ShowCollationFilterOpt) + filterExpr, err := ExprToExpression(ctx, *s.ShowCollationFilterOpt) if err != nil { return nil, err } @@ -672,7 +672,7 @@ func convertCreateProcedure(ctx *sql.Context, query string, c *sqlparser.DDL) (s func convertCall(ctx *sql.Context, c *sqlparser.Call) (sql.Node, error) { params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { - expr, err := exprToExpression(ctx, param) + expr, err := ExprToExpression(ctx, param) if err != nil { return nil, err } @@ -1038,7 +1038,7 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin var c sql.Expression var err error if chConstraint.Expr != nil { - c, err = exprToExpression(ctx, chConstraint.Expr) + c, err = ExprToExpression(ctx, chConstraint.Expr) if err != nil { return nil, err } @@ -1124,35 +1124,6 @@ func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { ), nil } -func convertCheckDefToConstraint(ctx *sql.Context, check *sql.CheckDefinition) (*sql.CheckConstraint, error) { - parsed, err := sqlparser.ParseStrictDDL(check.AlterStatement) - if err != nil { - return nil, err - } - - ddl, ok := parsed.(*sqlparser.DDL) - if !ok || ddl.ConstraintAction == "" || len(ddl.TableSpec.Constraints) != 1 || strings.ToLower(ddl.ConstraintAction) != sqlparser.AddStr { - return nil, ErrInvalidCheckConstraint.New(check.AlterStatement) - } - - parsedConstraint := ddl.TableSpec.Constraints[0] - chConstraint, ok := parsedConstraint.Details.(*sqlparser.CheckConstraintDefinition) - if !ok || chConstraint.Expr == nil { - return nil, ErrInvalidCheckConstraint.New(check.AlterStatement) - } - - c, err := exprToExpression(ctx, chConstraint.Expr) - if err != nil { - return nil, err - } - - return &sql.CheckConstraint{ - Name: parsedConstraint.Name, - Expr: c, - Enforced: chConstraint.Enforced, - }, nil -} - func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) { node, err := tableExprsToTable(ctx, d.TableExprs) if err != nil { @@ -1333,7 +1304,7 @@ func columnDefinitionToColumn(ctx *sql.Context, cd *sqlparser.ColumnDefinition, var defaultVal *sql.ColumnDefaultValue if cd.Type.Default != nil { - parsedExpr, err := exprToExpression(ctx, cd.Type.Default) + parsedExpr, err := ExprToExpression(ctx, cd.Type.Default) if err != nil { return nil, err } @@ -1390,7 +1361,7 @@ func valuesToValues(ctx *sql.Context, v sqlparser.Values) (sql.Node, error) { exprs := make([]sql.Expression, len(vt)) exprTuples[i] = exprs for j, e := range vt { - expr, err := exprToExpression(ctx, e) + expr, err := ExprToExpression(ctx, e) if err != nil { return nil, err } @@ -1445,7 +1416,7 @@ func tableExprToTable( case sqlparser.TableName: var node *plan.UnresolvedTable if t.AsOf != nil { - asOfExpr, err := exprToExpression(ctx, t.AsOf.Time) + asOfExpr, err := ExprToExpression(ctx, t.AsOf.Time) if err != nil { return nil, err } @@ -1498,7 +1469,7 @@ func tableExprToTable( return plan.NewCrossJoin(left, right), nil } - cond, err := exprToExpression(ctx, t.Condition.On) + cond, err := ExprToExpression(ctx, t.Condition.On) if err != nil { return nil, err } @@ -1517,7 +1488,7 @@ func tableExprToTable( } func whereToFilter(ctx *sql.Context, w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { - c, err := exprToExpression(ctx, w.Expr) + c, err := ExprToExpression(ctx, w.Expr) if err != nil { return nil, err } @@ -1537,7 +1508,7 @@ func orderByToSort(ctx *sql.Context, ob sqlparser.OrderBy, child sql.Node) (*pla func orderByToSortFields(ctx *sql.Context, ob sqlparser.OrderBy) ([]sql.SortField, error) { var sortFields []sql.SortField for _, o := range ob { - e, err := exprToExpression(ctx, o.Expr) + e, err := ExprToExpression(ctx, o.Expr) if err != nil { return nil, err } @@ -1576,7 +1547,7 @@ func limitToLimit( } func havingToHaving(ctx *sql.Context, having *sqlparser.Where, node sql.Node) (sql.Node, error) { - cond, err := exprToExpression(ctx, having.Expr) + cond, err := ExprToExpression(ctx, having.Expr) if err != nil { return nil, err } @@ -1604,7 +1575,7 @@ func offsetToOffset( // getInt64Literal returns an int64 *expression.Literal for the value given, or an unsupported error with the string // given if the expression doesn't represent an integer literal. func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*expression.Literal, error) { - e, err := exprToExpression(ctx, expr) + e, err := ExprToExpression(ctx, expr) if err != nil { return nil, err } @@ -1774,7 +1745,7 @@ func StringToColumnDefaultValue(ctx *sql.Context, exprStr string) (*sql.ColumnDe if !ok { return nil, fmt.Errorf("DefaultStringToExpression expected *sqlparser.AliasedExpr but received %T", parserSelect.SelectExprs[0]) } - parsedExpr, err := exprToExpression(ctx, aliasedExpr.Expr) + parsedExpr, err := ExprToExpression(ctx, aliasedExpr.Expr) if err != nil { return nil, err } @@ -1804,7 +1775,7 @@ func MustStringToColumnDefaultValue(ctx *sql.Context, exprStr string, outType sq return expr } -func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { +func ExprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { switch v := e.(type) { default: return nil, ErrUnsupportedSyntax.New(sqlparser.String(e)) @@ -1816,14 +1787,14 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error err error ) if v.Name != nil { - name, err = exprToExpression(ctx, v.Name) + name, err = ExprToExpression(ctx, v.Name) } else { - name, err = exprToExpression(ctx, v.StrVal) + name, err = ExprToExpression(ctx, v.StrVal) } if err != nil { return nil, err } - from, err := exprToExpression(ctx, v.From) + from, err := ExprToExpression(ctx, v.From) if err != nil { return nil, err } @@ -1831,7 +1802,7 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error if v.To == nil { return function.NewSubstring(name, from) } - to, err := exprToExpression(ctx, v.To) + to, err := ExprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -1841,7 +1812,7 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error case *sqlparser.IsExpr: return isExprToExpression(ctx, v) case *sqlparser.NotExpr: - c, err := exprToExpression(ctx, v.Expr) + c, err := ExprToExpression(ctx, v.Expr) if err != nil { return nil, err } @@ -1882,50 +1853,50 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error return expression.NewUnresolvedFunction(v.Name.Lowered(), isAggregateFunc(v), overToWindow(ctx, v.Over), exprs...), nil case *sqlparser.ParenExpr: - return exprToExpression(ctx, v.Expr) + return ExprToExpression(ctx, v.Expr) case *sqlparser.AndExpr: - lhs, err := exprToExpression(ctx, v.Left) + lhs, err := ExprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(ctx, v.Right) + rhs, err := ExprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewAnd(lhs, rhs), nil case *sqlparser.OrExpr: - lhs, err := exprToExpression(ctx, v.Left) + lhs, err := ExprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(ctx, v.Right) + rhs, err := ExprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewOr(lhs, rhs), nil case *sqlparser.ConvertExpr: - expr, err := exprToExpression(ctx, v.Expr) + expr, err := ExprToExpression(ctx, v.Expr) if err != nil { return nil, err } return expression.NewConvert(expr, v.Type.Type), nil case *sqlparser.RangeCond: - val, err := exprToExpression(ctx, v.Left) + val, err := ExprToExpression(ctx, v.Left) if err != nil { return nil, err } - lower, err := exprToExpression(ctx, v.From) + lower, err := ExprToExpression(ctx, v.From) if err != nil { return nil, err } - upper, err := exprToExpression(ctx, v.To) + upper, err := ExprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -1941,7 +1912,7 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error case sqlparser.ValTuple: var exprs = make([]sql.Expression, len(v)) for i, e := range v { - expr, err := exprToExpression(ctx, e) + expr, err := ExprToExpression(ctx, e) if err != nil { return nil, err } @@ -1968,9 +1939,9 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error return intervalExprToExpression(ctx, v) case *sqlparser.CollateExpr: // TODO: handle collation - return exprToExpression(ctx, v.Expr) + return ExprToExpression(ctx, v.Expr) case *sqlparser.ValuesFuncExpr: - col, err := exprToExpression(ctx, v.Name) + col, err := ExprToExpression(ctx, v.Name) if err != nil { return nil, err } @@ -1991,7 +1962,7 @@ func overToWindow(ctx *sql.Context, over *sqlparser.Over) *sql.Window { partitions := make([]sql.Expression, len(over.PartitionBy)) for i, expr := range over.PartitionBy { var err error - partitions[i], err = exprToExpression(ctx, expr) + partitions[i], err = ExprToExpression(ctx, expr) if err != nil { return nil } @@ -2080,7 +2051,7 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { } func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, error) { - e, err := exprToExpression(ctx, c.Expr) + e, err := ExprToExpression(ctx, c.Expr) if err != nil { return nil, err } @@ -2104,12 +2075,12 @@ func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, } func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) (sql.Expression, error) { - left, err := exprToExpression(ctx, c.Left) + left, err := ExprToExpression(ctx, c.Left) if err != nil { return nil, err } - right, err := exprToExpression(ctx, c.Right) + right, err := ExprToExpression(ctx, c.Right) if err != nil { return nil, err } @@ -2165,7 +2136,7 @@ func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) ( func groupByToExpressions(ctx *sql.Context, g sqlparser.GroupBy) ([]sql.Expression, error) { es := make([]sql.Expression, len(g)) for i, ve := range g { - e, err := exprToExpression(ctx, ve) + e, err := ExprToExpression(ctx, ve) if err != nil { return nil, err } @@ -2186,7 +2157,7 @@ func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expr } return expression.NewQualifiedStar(e.TableName.Name.String()), nil case *sqlparser.AliasedExpr: - expr, err := exprToExpression(ctx, e.Expr) + expr, err := ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2202,7 +2173,7 @@ func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expr func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expression, error) { switch strings.ToLower(e.Operator) { case sqlparser.MinusStr: - expr, err := exprToExpression(ctx, e.Expr) + expr, err := ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2210,7 +2181,7 @@ func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expres return expression.NewUnaryMinus(expr), nil case sqlparser.PlusStr: // Unary plus expressions do nothing (do not turn the expression positive). Just return the underlying expression. - return exprToExpression(ctx, e.Expr) + return ExprToExpression(ctx, e.Expr) default: return nil, ErrUnsupportedFeature.New("unary operator: " + e.Operator) @@ -2232,12 +2203,12 @@ func binaryExprToExpression(ctx *sql.Context, be *sqlparser.BinaryExpr) (sql.Exp sqlparser.IntDivStr, sqlparser.ModStr: - l, err := exprToExpression(ctx, be.Left) + l, err := ExprToExpression(ctx, be.Left) if err != nil { return nil, err } - r, err := exprToExpression(ctx, be.Right) + r, err := ExprToExpression(ctx, be.Right) if err != nil { return nil, err } @@ -2264,7 +2235,7 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi var err error if e.Expr != nil { - expr, err = exprToExpression(ctx, e.Expr) + expr, err = ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2273,13 +2244,13 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi var branches []expression.CaseBranch for _, w := range e.Whens { var cond sql.Expression - cond, err = exprToExpression(ctx, w.Cond) + cond, err = ExprToExpression(ctx, w.Cond) if err != nil { return nil, err } var val sql.Expression - val, err = exprToExpression(ctx, w.Val) + val, err = ExprToExpression(ctx, w.Val) if err != nil { return nil, err } @@ -2292,7 +2263,7 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi var elseExpr sql.Expression if e.Else != nil { - elseExpr, err = exprToExpression(ctx, e.Else) + elseExpr, err = ExprToExpression(ctx, e.Else) if err != nil { return nil, err } @@ -2302,7 +2273,7 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi } func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.Expression, error) { - expr, err := exprToExpression(ctx, e.Expr) + expr, err := ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2313,11 +2284,11 @@ func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql. func setExprsToExpressions(ctx *sql.Context, e sqlparser.SetExprs) ([]sql.Expression, error) { res := make([]sql.Expression, len(e)) for i, updateExpr := range e { - colName, err := exprToExpression(ctx, updateExpr.Name) + colName, err := ExprToExpression(ctx, updateExpr.Name) if err != nil { return nil, err } - innerExpr, err := exprToExpression(ctx, updateExpr.Expr) + innerExpr, err := ExprToExpression(ctx, updateExpr.Expr) if err != nil { return nil, err } diff --git a/sql/parse/util.go b/sql/parse/util.go index 9d7f3a3d9a..60e9d7a692 100644 --- a/sql/parse/util.go +++ b/sql/parse/util.go @@ -377,7 +377,7 @@ func parseExpr(ctx *sql.Context, str string) (sql.Expression, error) { return nil, errInvalidIndexExpression.New(str) } - return exprToExpression(ctx, selectExpr.Expr) + return ExprToExpression(ctx, selectExpr.Expr) } func readQuotableIdent(ident *string) parseFunc { diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index 8405b04814..e81e2ca45c 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -163,7 +163,7 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { // return false // case *Subquery: // err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) - // return false + // return falseErrInvalidCheckConstraint // } // return true //}) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index f159f471ef..71fae1a949 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -46,6 +46,7 @@ type InsertInto struct { ColumnNames []string IsReplace bool OnDupExprs []sql.Expression + Checks []*sql.CheckConstraint } // NewInsertInto creates an InsertInto node. @@ -82,6 +83,7 @@ type insertIter struct { ctx *sql.Context updateExprs []sql.Expression tableNode sql.Node + checks []*sql.CheckConstraint closed bool } @@ -117,6 +119,7 @@ func newInsertIter( values sql.Node, isReplace bool, onDupUpdateExpr []sql.Expression, + checks []*sql.CheckConstraint, row sql.Row, ) (*insertIter, error) { dstSchema := table.Schema() @@ -326,7 +329,7 @@ func (i insertIter) Close(ctx *sql.Context) error { // RowIter implements the Node interface. func (p *InsertInto) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return newInsertIter(ctx, p.Destination, p.Source, p.IsReplace, p.OnDupExprs, row) + return newInsertIter(ctx, p.Destination, p.Source, p.IsReplace, p.OnDupExprs, p.Checks, row) } // WithChildren implements the Node interface. From 607459e670a531d834d163ba18c12bd6ea5c4f36 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 12 Mar 2021 13:57:12 -0800 Subject: [PATCH 10/16] fix insert checking --- enginetest/enginetests.go | 60 +++++++++++++++-- enginetest/memory_engine_test.go | 4 ++ sql/analyzer/check_constraints.go | 42 ++++++++++-- sql/analyzer/rules.go | 3 +- sql/parse/parse.go | 1 + sql/parse/parse_test.go | 6 ++ sql/plan/alter_check.go | 105 ++++++++++++++++-------------- sql/plan/ddl.go | 6 +- sql/plan/insert.go | 36 +++++++--- 9 files changed, 193 insertions(+), 70 deletions(-) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index df6499a233..8ccc6ec70f 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2187,7 +2187,7 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { checks, err := cht.GetChecks(NewContext(harness)) require.NoError(err) - cmp := sql.CheckConstraint{ + con := sql.CheckConstraint{ Name: "chk2", Expr: expression.NewGreaterThan( expression.NewUnresolvedColumn("t1.b"), @@ -2195,7 +2195,8 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { ), Enforced: true, } - expected := []sql.CheckDefinition{*plan.NewCheckDefinition(&cmp)} + cmp, _ := plan.NewCheckDefinition(&con) + expected := []sql.CheckDefinition{*cmp} assert.Equal(t, expected, checks) @@ -2209,6 +2210,55 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { assert.True(t, sql.ErrTableColumnNotFound.Is(err)) } +func TestChecksOnInsert(t *testing.T, harness Harness) { + + //require := require.New(t) + + e := NewEngine(t, harness) + //e.Analyzer.Debug = true + //e.Analyzer.Verbose = true + + TestQuery(t, harness, e, + "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", + []sql.Row(nil), + nil, + ) + RunQuery(t, e, harness, "INSERT INTO t1 VALUES (1,1)") + TestQuery(t, harness, e, `SELECT * FROM t1`, + []sql.Row{ + {1, 1}, + }, + nil, + ) + RunQuery(t, e, harness, "INSERT INTO t1 VALUES (0,0)") + TestQuery(t, harness, e, `SELECT * FROM t1`, + []sql.Row{ + {1, 1}, + }, + nil, + ) + + ctx := NewContext(harness) + require.True(t, len(ctx.Warnings()) > 0) + + expectedCode := 3819 + condition := false + for _, warning := range ctx.Warnings() { + if warning.Code == expectedCode { + condition = true + break + } + } + + require.True(t, condition) + +} + func TestDisallowedCheckConstraints(t *testing.T, harness Harness) { require := require.New(t) e := NewEngine(t, harness) @@ -2280,7 +2330,7 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { checks, err := cht.GetChecks(NewContext(harness)) require.NoError(err) - cmp := sql.CheckConstraint{ + con := sql.CheckConstraint{ Name: "chk2", Expr: expression.NewGreaterThan( expression.NewUnresolvedColumn("t1.b"), @@ -2288,7 +2338,9 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { ), Enforced: true, } - expected := []sql.CheckDefinition{*plan.NewCheckDefinition(&cmp)} + cmp, _ := plan.NewCheckDefinition(&con) + expected := []sql.CheckDefinition{*cmp} + assert.Equal(t, expected, checks) // Some faulty create statements diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index d764f87103..e57e391d05 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -321,6 +321,10 @@ func TestCreateCheckConstraints(t *testing.T) { enginetest.TestCreateCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) } +func TestChecksOnInsert(t *testing.T) { + enginetest.TestChecksOnInsert(t, enginetest.NewDefaultMemoryHarness()) +} + func TestTestDisallowedCheckConstraints(t *testing.T) { enginetest.TestDisallowedCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) } diff --git a/sql/analyzer/check_constraints.go b/sql/analyzer/check_constraints.go index 9d46c9130e..d3a282a86c 100644 --- a/sql/analyzer/check_constraints.go +++ b/sql/analyzer/check_constraints.go @@ -83,28 +83,54 @@ func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.No switch node := n.(type) { case *plan.InsertInto: nc := *node - table, ok := nc.Destination.(sql.CheckTable) - if !ok { - return node, nil + table, err := plan.GetCheckTable(nc.Destination) + + if err != nil { + return node, err } + loadedChecks, err := loadChecksFromTable(ctx, table) + if err != nil { return nil, err } + if len(loadedChecks) != 0 { nc.Checks = loadedChecks } else { - nc.Checks = make([]*sql.CheckConstraint, 0) + nc.Checks = make([]sql.Expression, 0) } + return &nc, nil + //case *plan.RenameColumn: + // nc := *node + // table, err := plan.GetCheckTable(nc.) + // + // if err != nil { + // return node, err + // } + // + // loadedChecks, err := loadChecksFromTable(ctx, table) + // + // if err != nil { + // return nil, err + // } + // + // if len(loadedChecks) != 0 { + // nc.Checks = loadedChecks + // } else { + // nc.Checks = make([]sql.Expression, 0) + // } + // + // return &nc, nil default: return node, nil } }) } -func loadChecksFromTable(ctx *sql.Context, table sql.Table) ([]*sql.CheckConstraint, error) { - var loadedChecks []*sql.CheckConstraint +func loadChecksFromTable(ctx *sql.Context, table sql.Table) ([]sql.Expression, error) { + var loadedChecks []sql.Expression if checkTable, ok := table.(sql.CheckTable); ok { checks, err := checkTable.GetChecks(ctx) if err != nil { @@ -115,7 +141,9 @@ func loadChecksFromTable(ctx *sql.Context, table sql.Table) ([]*sql.CheckConstra if err != nil { return nil, err } - loadedChecks = append(loadedChecks, constraint) + if constraint.Enforced { + loadedChecks = append(loadedChecks, constraint.Expr) + } } } return loadedChecks, nil diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 7a04c5aa95..9195972065 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -22,9 +22,9 @@ import ( // DefaultRules. var OnceBeforeDefault = []Rule{ {"load_stored_procedures", loadStoredProcedures}, - {"load_check_constraints", loadChecks}, {"resolve_views", resolveViews}, {"resolve_tables", resolveTables}, + {"load_check_constraints", loadChecks}, {"resolve_set_variables", resolveSetVariables}, {"resolve_create_like", resolveCreateLike}, {"resolve_subqueries", resolveSubqueries}, @@ -62,7 +62,6 @@ var DefaultRules = []Rule{ // DefaultRules. var OnceAfterDefault = []Rule{ {"load_triggers", loadTriggers}, - {"load_checks", loadChecks}, {"process_truncate", processTruncate}, {"resolve_column_defaults", resolveColumnDefaults}, {"resolve_generators", resolveGenerators}, diff --git a/sql/parse/parse.go b/sql/parse/parse.go index b3dd50edb9..35feee23ed 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1121,6 +1121,7 @@ func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { isReplace, columnsToStrings(i.Columns), onDupExprs, + nil, ), nil } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 96c1a29871..5d68b7ca2e 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -1186,6 +1186,7 @@ var fixtures = map[string]sql.Node{ false, []string{"col1", "col2"}, []sql.Expression{}, + nil, ), `INSERT INTO t1 (col1, col2) VALUES (?, ?)`: plan.NewInsertInto( plan.NewUnresolvedTable("t1", ""), @@ -1196,6 +1197,7 @@ var fixtures = map[string]sql.Node{ false, []string{"col1", "col2"}, []sql.Expression{}, + nil, ), `UPDATE t1 SET col1 = ?, col2 = ? WHERE id = ?`: plan.NewUpdate( plan.NewFilter( @@ -1216,6 +1218,7 @@ var fixtures = map[string]sql.Node{ true, []string{"col1", "col2"}, []sql.Expression{}, + nil, ), `SHOW TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), false, nil), `SHOW FULL TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), true, nil), @@ -2402,6 +2405,7 @@ var fixtures = map[string]sql.Node{ false, []string{"a", "b"}, []sql.Expression{}, + nil, ), }, ), @@ -2429,6 +2433,7 @@ var fixtures = map[string]sql.Node{ false, []string{"a", "b"}, []sql.Expression{}, + nil, ), `CREATE TRIGGER myTrigger BEFORE UPDATE ON foo FOR EACH ROW INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, `INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, @@ -2446,6 +2451,7 @@ var fixtures = map[string]sql.Node{ false, []string{"a", "b"}, []sql.Expression{}, + nil, ), `CREATE TRIGGER myTrigger BEFORE UPDATE ON foo FOR EACH ROW FOLLOWS yourTrigger INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, `INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index e81e2ca45c..36b8ece1ab 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -26,8 +26,8 @@ var ( // ErrNoCheckConstraintSupport is returned when the table does not support CONSTRAINT CHECK operations. ErrNoCheckConstraintSupport = errors.NewKind("the table does not support check constraint operations: %s") - // ErrNoCheckFailed is returned when the check constraint evaluates to false - ErrNoCheckFailed = errors.NewKind("check failed: %s, %s") + // ErrCheckFailed is returned when the check constraint evaluates to false + ErrCheckFailed = errors.NewKind("check failed: %s, %s") ) type CreateCheck struct { @@ -71,11 +71,36 @@ func getCheckAlterableTable(t sql.Table) (sql.CheckAlterableTable, error) { return t, nil case sql.TableWrapper: return getCheckAlterableTable(t.Underlying()) + case *ResolvedTable: + return getCheckAlterableTable(t.Table) default: return nil, ErrNoCheckConstraintSupport.New(t.Name()) } } +func GetCheckTable(node sql.Node) (sql.CheckTable, error) { + switch node := node.(type) { + case sql.CheckTable: + return node, nil + case *ResolvedTable: + return getCheckTable(node.Table), nil + default: + return nil, ErrNoCheckConstraintSupport.New(node.String()) + } +} + +// getCheckTable returns the underlying getCheckTable for the table given, or nil if it isn't a getCheckTable +func getCheckTable(t sql.Table) sql.CheckTable { + switch t := t.(type) { + case sql.CheckTable: + return t + case sql.TableWrapper: + return getCheckTable(t.Underlying()) + default: + return nil + } +} + // Expressions implements the sql.Expressioner interface. func (c *CreateCheck) Expressions() []sql.Expression { return []sql.Expression{c.ChDef.Expr} @@ -121,7 +146,6 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { return err } - //check, err := ConvertCheckDefToConstraint(ctx, p.ChDef) if err != nil { return err } @@ -141,51 +165,17 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { return err } if val, ok := res.(bool); !ok || !val { - return ErrNoCheckFailed.New(p.ChDef.Expr.String(), row) + return ErrCheckFailed.New(p.ChDef.Expr.String(), row) } } - // Make sure that all columns are valid, in the table, and there are no duplicates - //cols := make(map[string]bool) - //for _, col := range chAlterable.Schema() { - // cols[col.Name] = true - //} - // - //sql.Inspect(p.ChDef.Expr, func(expr sql.Expression) bool { - // switch expr := expr.(type) { - // case *expression.UnresolvedColumn: - // if _, ok := cols[expr.Name()]; !ok { - // err = sql.ErrTableColumnNotFound.New(expr.Name()) - // return false - // } - // case *expression.UnresolvedFunction: - // err = sql.ErrInvalidConstraintFunctionsNotSupported.New(expr.String()) - // return false - // case *Subquery: - // err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) - // return falseErrInvalidCheckConstraint - // } - // return true - //}) - //if err != nil { - // return err - //} - //switch p.ChDef.Expr.(type): - - // case expression.BinaryExpression: - //for _, chCol := range p.ChDef.Expr. { - // if seen, ok := seenCols[fkCol]; ok { - // if !seen { - // seenCols[fkCol] = true - // } else { - // return ErrAddForeignKe yDuplicateColumn.New(fkCol) - // } - // } else { - // return sql.ErrTableColumnNotFound.New(fkCol) - // } - //} + check, err := NewCheckDefinition(p.ChDef) + + if err != nil { + return err + } - return chAlterable.CreateCheck(ctx, NewCheckDefinition(p.ChDef)) + return chAlterable.CreateCheck(ctx, check) } // WithChildren implements the Node interface. @@ -250,14 +240,33 @@ func (p DropCheck) String() string { return pr.String() } -func NewCheckDefinition(check *sql.CheckConstraint) *sql.CheckDefinition { +func NewCheckDefinition(check *sql.CheckConstraint) (*sql.CheckDefinition, error) { + //var new sql.Node + //cleaned, err := expression.TransformUp(check.Expr, func(e sql.Expression) (sql.Expression, error) { + // switch expr := e.(type) { + // case *expression.GetField: + // return expr.WithTable(""), nil + // default: + // return expr, nil + // } + // //return new, nil + //}) + //if err != nil { + // return nil, err + //} + + enforced := "" + if !check.Enforced { + enforced = " NOT ENFORCED" + } + return &sql.CheckDefinition{ Name: check.Name, AlterStatement: fmt.Sprintf( - "ALTER TABLE _ ADD CONSTRAINT %s CHECK %s ENFORCED %v", + "ALTER TABLE _ ADD CONSTRAINT %s CHECK (%s)%s", check.Name, check.Expr.String(), - check.Enforced, + enforced, ), - } + }, nil } diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 1c938d149f..5cba03499d 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -223,7 +223,11 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error return sql.RowsToRowIter(), ErrNoCheckConstraintSupport.New(c.name) } for _, ch := range c.chDefs { - err = chAlterable.CreateCheck(ctx, NewCheckDefinition(ch)) + check, err := NewCheckDefinition(ch) + if err != nil { + return nil, err + } + err = chAlterable.CreateCheck(ctx, check) if err != nil { return sql.RowsToRowIter(), err } diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 71fae1a949..2f31dbb627 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -46,17 +46,18 @@ type InsertInto struct { ColumnNames []string IsReplace bool OnDupExprs []sql.Expression - Checks []*sql.CheckConstraint + Checks []sql.Expression } // NewInsertInto creates an InsertInto node. -func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string, onDupExprs []sql.Expression) *InsertInto { +func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string, onDupExprs []sql.Expression, checks []sql.Expression) *InsertInto { return &InsertInto{ Destination: dst, Source: src, ColumnNames: cols, IsReplace: isReplace, OnDupExprs: onDupExprs, + Checks: checks, } } @@ -82,8 +83,8 @@ type insertIter struct { rowSource sql.RowIter ctx *sql.Context updateExprs []sql.Expression + checks []sql.Expression tableNode sql.Node - checks []*sql.CheckConstraint closed bool } @@ -119,7 +120,7 @@ func newInsertIter( values sql.Node, isReplace bool, onDupUpdateExpr []sql.Expression, - checks []*sql.CheckConstraint, + checks []sql.Expression, row sql.Row, ) (*insertIter, error) { dstSchema := table.Schema() @@ -156,6 +157,7 @@ func newInsertIter( updater: updater, rowSource: rowIter, updateExprs: onDupUpdateExpr, + checks: checks, ctx: ctx, }, nil } @@ -183,6 +185,19 @@ func (i insertIter) Next() (returnRow sql.Row, returnErr error) { return nil, err } + // apply check constraints + var res interface{} + for _, check := range i.checks { + res, err = check.Eval(i.ctx, row) + if err != nil { + return nil, err + } + if val, ok := res.(bool); !ok || !val { + i.ctx.Warn(3819, "Check constraint '%s' is violated", check.String()) + return nil, nil + } + } + // Do any necessary type conversions to the target schema for i, col := range i.schema { if row[i] != nil { @@ -383,15 +398,15 @@ func validateNullability(dstSchema sql.Schema, row sql.Row) error { } func (p *InsertInto) Expressions() []sql.Expression { - return p.OnDupExprs + return append(p.OnDupExprs, p.Checks...) } func (p *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { - if len(newExprs) != len(p.OnDupExprs) { - return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.OnDupExprs), 1) + if len(newExprs) != len(p.OnDupExprs)+len(p.Checks) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.OnDupExprs)+len(p.Checks), 1) } - return NewInsertInto(p.Destination, p.Source, p.IsReplace, p.ColumnNames, newExprs), nil + return NewInsertInto(p.Destination, p.Source, p.IsReplace, p.ColumnNames, newExprs[:len(p.OnDupExprs)], newExprs[len(p.OnDupExprs):]), nil } // Resolved implements the Resolvable interface. @@ -404,5 +419,10 @@ func (p *InsertInto) Resolved() bool { return false } } + for _, checkExpr := range p.Checks { + if !checkExpr.Resolved() { + return false + } + } return true } From 103ba519149d7ed5340dafb8db7a2800a4991cda Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 12 Mar 2021 14:36:53 -0800 Subject: [PATCH 11/16] Comments for modify column --- sql/analyzer/check_constraints.go | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/sql/analyzer/check_constraints.go b/sql/analyzer/check_constraints.go index d3a282a86c..eafa3756f6 100644 --- a/sql/analyzer/check_constraints.go +++ b/sql/analyzer/check_constraints.go @@ -102,27 +102,10 @@ func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.No } return &nc, nil + // TODO : reimplement modify column nodes and throw errors here to protect check columns + //case *plan.DropColumn: //case *plan.RenameColumn: - // nc := *node - // table, err := plan.GetCheckTable(nc.) - // - // if err != nil { - // return node, err - // } - // - // loadedChecks, err := loadChecksFromTable(ctx, table) - // - // if err != nil { - // return nil, err - // } - // - // if len(loadedChecks) != 0 { - // nc.Checks = loadedChecks - // } else { - // nc.Checks = make([]sql.Expression, 0) - // } - // - // return &nc, nil + //case *plan.ModifyColumn: default: return node, nil } From e4ad9332947c624a914d8f9a3262b214b134e599 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 12 Mar 2021 15:20:13 -0800 Subject: [PATCH 12/16] Add more tests --- enginetest/enginetests.go | 60 ++++++++++++++++++++++++++------------- sql/parse/parse_test.go | 42 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 20 deletions(-) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 8ccc6ec70f..efe8502c96 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2159,8 +2159,6 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { require := require.New(t) e := NewEngine(t, harness) - //e.Analyzer.Debug = true - //e.Analyzer.Verbose = true TestQuery(t, harness, e, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", @@ -2168,7 +2166,12 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { nil, ) TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", + "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 0)", + []sql.Row(nil), + nil, + ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED", []sql.Row(nil), nil, ) @@ -2187,16 +2190,27 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { checks, err := cht.GetChecks(NewContext(harness)) require.NoError(err) - con := sql.CheckConstraint{ - Name: "chk2", - Expr: expression.NewGreaterThan( - expression.NewUnresolvedColumn("t1.b"), - expression.NewLiteral(int8(0), sql.Int8), - ), - Enforced: true, + con := []sql.CheckConstraint{ + { + Name: "chk1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.b"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + { + Name: "chk2", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.b"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: false, + }, } - cmp, _ := plan.NewCheckDefinition(&con) - expected := []sql.CheckDefinition{*cmp} + cmp1, _ := plan.NewCheckDefinition(&con[0]) + cmp2, _ := plan.NewCheckDefinition(&con[1]) + expected := []sql.CheckDefinition{*cmp1, *cmp2} assert.Equal(t, expected, checks) @@ -2212,22 +2226,28 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { func TestChecksOnInsert(t *testing.T, harness Harness) { - //require := require.New(t) - e := NewEngine(t, harness) - //e.Analyzer.Debug = true - //e.Analyzer.Verbose = true TestQuery(t, harness, e, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", []sql.Row(nil), nil, ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 1) NOT ENFORCED", + []sql.Row(nil), + nil, + ) TestQuery(t, harness, e, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", []sql.Row(nil), nil, ) + TestQuery(t, harness, e, + "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (b > -1) ENFORCED", + []sql.Row(nil), + nil, + ) RunQuery(t, e, harness, "INSERT INTO t1 VALUES (1,1)") TestQuery(t, harness, e, `SELECT * FROM t1`, []sql.Row{ @@ -2296,7 +2316,7 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { nil, ) TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", + "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED", []sql.Row(nil), nil, ) @@ -2306,7 +2326,7 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { nil, ) TestQuery(t, harness, e, - "ALTER TABLE t1 DROP CONSTRAINT chk3", + "ALTER TABLE t1 DROP CONSTRAINT chk2", []sql.Row(nil), nil, ) @@ -2331,9 +2351,9 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { require.NoError(err) con := sql.CheckConstraint{ - Name: "chk2", + Name: "chk3", Expr: expression.NewGreaterThan( - expression.NewUnresolvedColumn("t1.b"), + expression.NewUnresolvedColumn("t1.c"), expression.NewLiteral(int8(0), sql.Int8), ), Enforced: true, diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 5d68b7ca2e..a6fae1566f 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -614,6 +614,48 @@ var fixtures = map[string]sql.Node{ }}, }, ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY CHECK (a > 0) ENFORCED)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY CHECK (a > 0) NOT ENFORCED)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: false, + }}, + }, + ), `DROP TABLE foo;`: plan.NewDropTable( sql.UnresolvedDatabase(""), false, "foo", ), From ff1536cd02a658f6ad9cccaa1be77058c8305630 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 12 Mar 2021 15:43:32 -0800 Subject: [PATCH 13/16] Fix bad merge --- sql/analyzer/rules.go | 1 - sql/parse/parse.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 03067285e4..3fc580eae1 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -33,7 +33,6 @@ var OnceBeforeDefault = []Rule{ {"resolve_describe_query", resolveDescribeQuery}, {"check_unique_table_names", checkUniqueTableNames}, {"validate_create_trigger", validateCreateTrigger}, - {"validate_stored_procedure", validateStoredProcedure}, {"validate_create_procedure", validateCreateProcedure}, {"validate_check_constraint", validateCreateCheck}, {"assign_info_schema", assignInfoSchema}, diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 6802597507..21adb72b04 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -1176,7 +1176,7 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin OnUpdate: convertReferenceAction(fkConstraint.OnUpdate), OnDelete: convertReferenceAction(fkConstraint.OnDelete), }, nil - } else if chConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinitiong); ok { + } else if chConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinition); ok { var c sql.Expression var err error if chConstraint.Expr != nil { From dbb2af22b8324a9f38d9f2e11549ceb0bbdd73e8 Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 12 Mar 2021 16:15:38 -0800 Subject: [PATCH 14/16] Fix sloppy errors --- enginetest/harness.go | 2 +- memory/table.go | 2 +- sql/analyzer/check_constraints.go | 13 +++---- sql/core.go | 2 +- sql/plan/alter_check.go | 63 ++++++++++--------------------- sql/plan/ddl.go | 2 +- 6 files changed, 29 insertions(+), 55 deletions(-) diff --git a/enginetest/harness.go b/enginetest/harness.go index 700ce74c67..0e66126b97 100755 --- a/enginetest/harness.go +++ b/enginetest/harness.go @@ -67,7 +67,7 @@ type ForeignKeyHarness interface { } // CheckConstraintHarness is an extension to Harness that lets an integrator test their implementation with check constraints. -// Integrator tables must implement sql.CheckAlterableTable and sql.CheckConstraintTable. +// Integrator tables must implement sql.CheckAlterableTable and sql.CheckTable. type CheckConstraintHarness interface { Harness // SupportsCheckConstraint returns whether this harness should accept CREATE CHECK statements as part of test diff --git a/memory/table.go b/memory/table.go index 168505e5ba..42a0360613 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1046,7 +1046,7 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { return append(indexes, nonPrimaryIndexes...), nil } -// GetForeignKeys implements sql.CheckTable +// GetForeignKeys implements sql.ForeignKeyTable func (t *Table) GetForeignKeys(_ *sql.Context) ([]sql.ForeignKeyConstraint, error) { return t.foreignKeys, nil } diff --git a/sql/analyzer/check_constraints.go b/sql/analyzer/check_constraints.go index eafa3756f6..5b8686c2d3 100644 --- a/sql/analyzer/check_constraints.go +++ b/sql/analyzer/check_constraints.go @@ -25,9 +25,9 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" ) -// validateCreateTrigger handles CreateTrigger nodes, resolving references to "old" and "new" table references in -// the trigger body. Also validates that these old and new references are being used appropriately -- they are only -// valid for certain kinds of triggers and certain statements. +// validateCreateCheck handles CreateCheck nodes, resolving references to "old" and "new" table references in +// the check body. Also validates that these old and new references are being used appropriately -- they are only +// valid for certain kinds of checks and certain statements. func validateCreateCheck(ctx *sql.Context, a *Analyzer, node sql.Node, scope *Scope) (sql.Node, error) { ct, ok := node.(*plan.CreateCheck) if !ok { @@ -38,6 +38,7 @@ func validateCreateCheck(ctx *sql.Context, a *Analyzer, node sql.Node, scope *Sc if !ok { return node, nil } + checkCols := make(map[string]bool) for _, col := range chAlterable.Schema() { checkCols[col.Name] = true @@ -73,8 +74,8 @@ func validateCreateCheck(ctx *sql.Context, a *Analyzer, node sql.Node, scope *Sc return ct, nil } -// loadChecks loads any triggers that are required for a plan node to operate properly (except for nodes dealing with -// trigger execution). +// loadChecks loads any checks that are required for a plan node to operate properly (except for nodes dealing with +// check execution). func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) { span, _ := ctx.Span("loadChecks") defer span.Finish() @@ -84,13 +85,11 @@ func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.No case *plan.InsertInto: nc := *node table, err := plan.GetCheckTable(nc.Destination) - if err != nil { return node, err } loadedChecks, err := loadChecksFromTable(ctx, table) - if err != nil { return nil, err } diff --git a/sql/core.go b/sql/core.go index b4686c28de..a693fd170c 100644 --- a/sql/core.go +++ b/sql/core.go @@ -370,7 +370,7 @@ type CheckTable interface { GetChecks(ctx *Context) ([]CheckDefinition, error) } -// ForeignKeyAlterableTable represents a table that supports foreign key modification operations. +// CheckAlterableTable represents a table that supports check constraints. type CheckAlterableTable interface { Table // CreateCheck creates an check constraint for this table, using the provided parameters. diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index 36b8ece1ab..027a3a5172 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -32,25 +32,25 @@ var ( type CreateCheck struct { UnaryNode - ChDef *sql.CheckConstraint + Check *sql.CheckConstraint } type DropCheck struct { UnaryNode - ChDef *sql.CheckConstraint + Check *sql.CheckConstraint } -func NewAlterAddCheck(table sql.Node, chDef *sql.CheckConstraint) *CreateCheck { +func NewAlterAddCheck(table sql.Node, check *sql.CheckConstraint) *CreateCheck { return &CreateCheck{ UnaryNode: UnaryNode{table}, - ChDef: chDef, + Check: check, } } -func NewAlterDropCheck(table sql.Node, chDef *sql.CheckConstraint) *DropCheck { +func NewAlterDropCheck(table sql.Node, check *sql.CheckConstraint) *DropCheck { return &DropCheck{ UnaryNode: UnaryNode{Child: table}, - ChDef: chDef, + Check: check, } } @@ -103,13 +103,13 @@ func getCheckTable(t sql.Table) sql.CheckTable { // Expressions implements the sql.Expressioner interface. func (c *CreateCheck) Expressions() []sql.Expression { - return []sql.Expression{c.ChDef.Expr} + return []sql.Expression{c.Check.Expr} } // Resolved implements the Resolvable interface. func (c *CreateCheck) Resolved() bool { ok := true - sql.Inspect(c.ChDef.Expr, func(expr sql.Expression) bool { + sql.Inspect(c.Check.Expr, func(expr sql.Expression) bool { switch expr.(type) { case *expression.UnresolvedColumn: ok = false @@ -127,16 +127,8 @@ func (c *CreateCheck) WithExpressions(exprs ...sql.Expression) (sql.Node, error) } nc := *c - nc.ChDef.Expr = exprs[0] + nc.Check.Expr = exprs[0] return &nc, nil - //return &CreateCheck{ - // UnaryNode: c.UnaryNode, - // ChDef: &sql.CheckConstraint{ - // Name: c.ChDef.Name, - // Expr: exprs[0], - // Enforced: c.ChDef.Enforced, - // }, - //}, nil } // Execute inserts the rows in the database. @@ -146,31 +138,28 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { return err } - if err != nil { - return err - } // check existing rows in table var res interface{} rowIter, err := p.UnaryNode.Child.RowIter(ctx, nil) if err != nil { return err } + for { row, err := rowIter.Next() if row == nil || err != io.EOF { break } - res, err = p.ChDef.Expr.Eval(ctx, row) + res, err = p.Check.Expr.Eval(ctx, row) if err != nil { return err } if val, ok := res.(bool); !ok || !val { - return ErrCheckFailed.New(p.ChDef.Expr.String(), row) + return ErrCheckFailed.New(p.Check.Expr.String(), row) } } - check, err := NewCheckDefinition(p.ChDef) - + check, err := NewCheckDefinition(p.Check) if err != nil { return err } @@ -183,7 +172,7 @@ func (p *CreateCheck) WithChildren(children ...sql.Node) (sql.Node, error) { if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return NewAlterAddCheck(children[0], p.ChDef), nil + return NewAlterAddCheck(children[0], p.Check), nil } func (p *CreateCheck) Schema() sql.Schema { return nil } @@ -198,10 +187,10 @@ func (p *CreateCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error func (p CreateCheck) String() string { pr := sql.NewTreePrinter() - _ = pr.WriteNode("AddCheck(%s)", p.ChDef.Name) + _ = pr.WriteNode("AddCheck(%s)", p.Check.Name) _ = pr.WriteChildren( fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String()), - fmt.Sprintf("Expr(%s)", p.ChDef.Expr.String()), + fmt.Sprintf("Expr(%s)", p.Check.Expr.String()), ) return pr.String() } @@ -212,7 +201,7 @@ func (p *DropCheck) Execute(ctx *sql.Context) error { if err != nil { return err } - return chAlterable.DropCheck(ctx, p.ChDef.Name) + return chAlterable.DropCheck(ctx, p.Check.Name) } // RowIter implements the Node interface. @@ -229,32 +218,18 @@ func (p *DropCheck) WithChildren(children ...sql.Node) (sql.Node, error) { if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - return NewAlterDropCheck(children[0], p.ChDef), nil + return NewAlterDropCheck(children[0], p.Check), nil } func (p *DropCheck) Schema() sql.Schema { return nil } func (p DropCheck) String() string { pr := sql.NewTreePrinter() - _ = pr.WriteNode("DropCheck(%s)", p.ChDef.Name) + _ = pr.WriteNode("DropCheck(%s)", p.Check.Name) _ = pr.WriteChildren(fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String())) return pr.String() } func NewCheckDefinition(check *sql.CheckConstraint) (*sql.CheckDefinition, error) { - //var new sql.Node - //cleaned, err := expression.TransformUp(check.Expr, func(e sql.Expression) (sql.Expression, error) { - // switch expr := e.(type) { - // case *expression.GetField: - // return expr.WithTable(""), nil - // default: - // return expr, nil - // } - // //return new, nil - //}) - //if err != nil { - // return nil, err - //} - enforced := "" if !check.Enforced { enforced = " NOT ENFORCED" diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 5cba03499d..72f554bb4b 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -112,10 +112,10 @@ type CreateTable struct { ddlNode name string schema sql.Schema + ifNotExists bool fkDefs []*sql.ForeignKeyConstraint chDefs []*sql.CheckConstraint idxDefs []*IndexDefinition - ifNotExists bool like sql.Node } From de93a9a5b6dfb2fa6b305d0584a408816c2e8c1e Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Fri, 12 Mar 2021 16:43:38 -0800 Subject: [PATCH 15/16] Fix error message --- sql/plan/alter_check.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index 027a3a5172..d1a987d21f 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -27,7 +27,7 @@ var ( ErrNoCheckConstraintSupport = errors.NewKind("the table does not support check constraint operations: %s") // ErrCheckFailed is returned when the check constraint evaluates to false - ErrCheckFailed = errors.NewKind("check failed: %s, %s") + ErrCheckFailed = errors.NewKind("check constraint %s is violated.") ) type CreateCheck struct { @@ -155,7 +155,7 @@ func (p *CreateCheck) Execute(ctx *sql.Context) error { return err } if val, ok := res.(bool); !ok || !val { - return ErrCheckFailed.New(p.Check.Expr.String(), row) + return ErrCheckFailed.New(p.Check.Name) } } From 3cc12aa85030045c7223d7e3452e0d36e3fa324f Mon Sep 17 00:00:00 2001 From: Max Hoffman Date: Wed, 17 Mar 2021 12:36:20 -0700 Subject: [PATCH 16/16] Zach's comments, bump vitess --- enginetest/enginetests.go | 96 +++++-------------- enginetest/harness.go | 9 -- go.mod | 3 +- go.sum | 2 + sql/analyzer/check_constraints.go | 19 +++- sql/analyzer/join_search.go | 6 +- sql/analyzer/join_search_test.go | 8 +- sql/analyzer/rules.go | 2 +- .../function/aggregation/json_agg_test.go | 2 +- sql/expression/function/json_unsupported.go | 13 --- sql/json.go | 1 - sql/json_test.go | 10 +- sql/json_value.go | 24 ++--- sql/parse/parse.go | 2 + sql/parse/parse_test.go | 11 +++ sql/plan/alter_check.go | 23 ----- sql/plan/create_index.go | 2 +- sql/testutils.go | 1 - 18 files changed, 82 insertions(+), 152 deletions(-) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 9471a4c732..6e71190027 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2192,21 +2192,10 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { e := NewEngine(t, harness) - TestQuery(t, harness, e, - "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 0)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED", - []sql.Row(nil), - nil, - ) + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT CHECK (b > 1)") db, err := e.Catalog.Database("mydb") require.NoError(err) @@ -2239,10 +2228,19 @@ func TestCreateCheckConstraints(t *testing.T, harness Harness) { ), Enforced: false, }, + { + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.b"), + expression.NewLiteral(int8(1), sql.Int8), + ), + Enforced: true, + }, } cmp1, _ := plan.NewCheckDefinition(&con[0]) cmp2, _ := plan.NewCheckDefinition(&con[1]) - expected := []sql.CheckDefinition{*cmp1, *cmp2} + cmp3, _ := plan.NewCheckDefinition(&con[2]) + expected := []sql.CheckDefinition{*cmp1, *cmp2, *cmp3} assert.Equal(t, expected, checks) @@ -2260,26 +2258,10 @@ func TestChecksOnInsert(t *testing.T, harness Harness) { e := NewEngine(t, harness) - TestQuery(t, harness, e, - "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 1) NOT ENFORCED", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (b > -1) ENFORCED", - []sql.Row(nil), - nil, - ) + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 1) NOT ENFORCED") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (b > -1) ENFORCED") RunQuery(t, e, harness, "INSERT INTO t1 VALUES (1,1)") TestQuery(t, harness, e, `SELECT * FROM t1`, []sql.Row{ @@ -2316,11 +2298,7 @@ func TestDisallowedCheckConstraints(t *testing.T, harness Harness) { e := NewEngine(t, harness) var err error - TestQuery(t, harness, e, - "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)", - []sql.Row(nil), - nil, - ) + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)") // functions, UDFs, procedures _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (current_user = \"root@\")") @@ -2337,36 +2315,12 @@ func TestDropCheckConstraints(t *testing.T, harness Harness) { e := NewEngine(t, harness) - TestQuery(t, harness, e, - "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER, c integer)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (a > 0)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (c > 0)", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 DROP CONSTRAINT chk2", - []sql.Row(nil), - nil, - ) - TestQuery(t, harness, e, - "ALTER TABLE t1 DROP CHECK chk1", - []sql.Row(nil), - nil, - ) + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER, c integer)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (a > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (c > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 DROP CONSTRAINT chk2") + RunQuery(t, e, harness, "ALTER TABLE t1 DROP CHECK chk1") db, err := e.Catalog.Database("mydb") require.NoError(err) diff --git a/enginetest/harness.go b/enginetest/harness.go index 0e66126b97..0a5b104e00 100755 --- a/enginetest/harness.go +++ b/enginetest/harness.go @@ -66,15 +66,6 @@ type ForeignKeyHarness interface { SupportsForeignKeys() bool } -// CheckConstraintHarness is an extension to Harness that lets an integrator test their implementation with check constraints. -// Integrator tables must implement sql.CheckAlterableTable and sql.CheckTable. -type CheckConstraintHarness interface { - Harness - // SupportsCheckConstraint returns whether this harness should accept CREATE CHECK statements as part of test - // setup. - SupportsCheckConstraint() bool -} - // VersionedDBHarness is an extension to Harness that lets an integrator test their implementation of versioned (AS OF) // queries. Integrators must implement sql.VersionedDatabase. For each table version being created, there will be a // call to NewTableAsOf, some number of Delete and Insert operations, and then a call to SnapshotTable. diff --git a/go.mod b/go.mod index 7fac09dc7c..1f01113e9b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ require ( github.com/VividCortex/gohistogram v1.0.0 // indirect github.com/cespare/xxhash v1.1.0 github.com/dolthub/sqllogictest/go v0.0.0-20201105013724-5123fc66e12c - github.com/dolthub/vitess v0.0.0-20210316150645-9bf20be78424 + github.com/dolthub/vitess v0.0.0-20210317175908-b2663b2c4d9c github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect github.com/go-kit/kit v0.9.0 github.com/go-sql-driver/mysql v1.4.1 @@ -25,6 +25,7 @@ require ( github.com/src-d/go-oniguruma v1.1.0 github.com/stretchr/testify v1.4.0 github.com/tebeka/strftime v0.1.4 // indirect + google.golang.org/genproto v0.0.0-20190926190326-7ee9db18f195 google.golang.org/grpc v1.27.0 // indirect gopkg.in/src-d/go-errors.v1 v1.0.0 ) diff --git a/go.sum b/go.sum index 3b585820b2..07c028ddb2 100755 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201105013724-5123fc66e12c h1:ZIo6IOX github.com/dolthub/sqllogictest/go v0.0.0-20201105013724-5123fc66e12c/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= github.com/dolthub/vitess v0.0.0-20210316150645-9bf20be78424 h1:QGVknlmy9WZoTOTyC50W+HCho7t3/CrpMxJZm+tQ25A= github.com/dolthub/vitess v0.0.0-20210316150645-9bf20be78424/go.mod h1:hUE8oSk2H5JZnvtlLBhJPYC8WZCA5AoSntdLTcBvdBM= +github.com/dolthub/vitess v0.0.0-20210317175908-b2663b2c4d9c h1:u/hKK5EBh+FJlZlcMjSYLB9qXTG+pnyaa6Vww8sSHvY= +github.com/dolthub/vitess v0.0.0-20210317175908-b2663b2c4d9c/go.mod h1:hUE8oSk2H5JZnvtlLBhJPYC8WZCA5AoSntdLTcBvdBM= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 h1:Ghm4eQYC0nEPnSJdVkTrXpu9KtoVCSo1hg7mtI7G9KU= diff --git a/sql/analyzer/check_constraints.go b/sql/analyzer/check_constraints.go index 5b8686c2d3..f6440a4a07 100644 --- a/sql/analyzer/check_constraints.go +++ b/sql/analyzer/check_constraints.go @@ -52,7 +52,12 @@ func validateCreateCheck(ctx *sql.Context, a *Analyzer, node sql.Node, scope *Sc // Make sure that all columns are valid, in the table, and there are no duplicates switch expr := e.(type) { - case *expression.UnresolvedColumn: + case *deferredColumn: + if _, ok := checkCols[expr.Name()]; !ok { + err = sql.ErrTableColumnNotFound.New(expr.Name()) + return false + } + case *expression.GetField: if _, ok := checkCols[expr.Name()]; !ok { err = sql.ErrTableColumnNotFound.New(expr.Name()) return false @@ -84,9 +89,15 @@ func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.No switch node := n.(type) { case *plan.InsertInto: nc := *node - table, err := plan.GetCheckTable(nc.Destination) - if err != nil { - return node, err + + rtable, ok := nc.Destination.(*plan.ResolvedTable) + if !ok { + return node, nil + } + + table, ok := rtable.Table.(sql.CheckAlterableTable) + if !ok { + return node, nil } loadedChecks, err := loadChecksFromTable(ctx, table) diff --git a/sql/analyzer/join_search.go b/sql/analyzer/join_search.go index af36c73d20..a512c07cde 100755 --- a/sql/analyzer/join_search.go +++ b/sql/analyzer/join_search.go @@ -207,7 +207,8 @@ func (jo *joinOrderNode) applyJoinHintTables(tables []string) ([]string, error) assigned := make(map[int]struct{}) order := []int{} remaining := tables -START: for { +START: + for { var i int for i = range jo.commutes { if _, ok := assigned[i]; ok { @@ -368,7 +369,7 @@ func (jo *joinOrderNode) visitJoinSearchNodes(cb func(n *joinSearchNode) bool) { } else if jo.left != nil { stop := false jo.left.visitJoinSearchNodes(func(l *joinSearchNode) bool { - jo.right.visitJoinSearchNodes(func (r *joinSearchNode) bool { + jo.right.visitJoinSearchNodes(func(r *joinSearchNode) bool { if !cb(&joinSearchNode{left: l, right: r}) { stop = true } @@ -439,7 +440,6 @@ func newJoinOrderNode(node sql.Node) *joinOrderNode { } } - // A joinSearchNode is a simplified type representing a join tree node, which is either an internal node (a join) or a // leaf node (a table). The top level node in a join tree is always an internal node. Every internal node has both a // left and a right child. diff --git a/sql/analyzer/join_search_test.go b/sql/analyzer/join_search_test.go index 78b2803209..ac18e99af2 100755 --- a/sql/analyzer/join_search_test.go +++ b/sql/analyzer/join_search_test.go @@ -532,9 +532,9 @@ func TestBuildJoinTree(t *testing.T) { }, }, { - name: "explicit subtree, A((EB)(DC))", + name: "explicit subtree, A((EB)(DC))", tableOrder: tableOrder("A", &joinOrderNode{ - left: tableOrder("E", "B"), + left: tableOrder("E", "B"), right: tableOrder("D", "C"), }), joinConds: []*joinCond{ @@ -562,9 +562,9 @@ func TestBuildJoinTree(t *testing.T) { }, }, { - name: "explicit subtree, A((EB)(D))C", + name: "explicit subtree, A((EB)(D))C", tableOrder: tableOrder("A", &joinOrderNode{ - left: tableOrder("E", "B"), + left: tableOrder("E", "B"), right: tableOrder("D"), }, "C"), joinConds: []*joinCond{ diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 3fc580eae1..9180f4e53e 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -34,7 +34,6 @@ var OnceBeforeDefault = []Rule{ {"check_unique_table_names", checkUniqueTableNames}, {"validate_create_trigger", validateCreateTrigger}, {"validate_create_procedure", validateCreateProcedure}, - {"validate_check_constraint", validateCreateCheck}, {"assign_info_schema", assignInfoSchema}, } @@ -46,6 +45,7 @@ var DefaultRules = []Rule{ {"pushdown_groupby_aliases", pushdownGroupByAliases}, {"qualify_columns", qualifyColumns}, {"resolve_columns", resolveColumns}, + {"validate_check_constraint", validateCreateCheck}, {"resolve_bareword_set_variables", resolveBarewordSetVariables}, {"resolve_database", resolveDatabase}, {"expand_stars", expandStars}, diff --git a/sql/expression/function/aggregation/json_agg_test.go b/sql/expression/function/aggregation/json_agg_test.go index b91b957801..450f07e74c 100644 --- a/sql/expression/function/aggregation/json_agg_test.go +++ b/sql/expression/function/aggregation/json_agg_test.go @@ -69,7 +69,7 @@ func TestJsonArrayAgg_Empty(t *testing.T) { v, err := j.Eval(ctx, b) assert.NoError(err) - assert.Equal(sql.JSONDocument{Val:[]interface{}(nil)}, v) + assert.Equal(sql.JSONDocument{Val: []interface{}(nil)}, v) } func TestJsonArrayAgg_JSON(t *testing.T) { diff --git a/sql/expression/function/json_unsupported.go b/sql/expression/function/json_unsupported.go index d4f3c1456b..f7b393434c 100644 --- a/sql/expression/function/json_unsupported.go +++ b/sql/expression/function/json_unsupported.go @@ -27,7 +27,6 @@ var ErrUnsupportedJSONFunction = errors.NewKind("unsupported JSON function: %s") // JSON search functions // /////////////////////////// - // JSON_CONTAINS_PATH(json_doc, one_or_all, path[, path] ...) // // JSONContainsPath Returns 0 or 1 to indicate whether a JSON document contains data at a given path or paths. Returns @@ -185,12 +184,10 @@ func (j JSONValue) FunctionName() string { // https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of // TODO(andy): relocate - ///////////////////////////// // JSON creation functions // ///////////////////////////// - // JSON_ARRAY([val[, val] ...]) // // JSONArray Evaluates a (possibly empty) list of values and returns a JSON array containing those values. @@ -260,12 +257,10 @@ func (j JSONQuote) FunctionName() string { return "json_quote" } - ///////////////////////////////// // JSON modification functions // ///////////////////////////////// - // JSON_ARRAY_APPEND(json_doc, path, val[, path, val] ...) // // JSONArrayAppend Appends values to the end of the indicated arrays within a JSON document and returns the result. @@ -520,12 +515,10 @@ func (j JSONSet) FunctionName() string { return "json_set" } - ////////////////////////////// // JSON attribute functions // ////////////////////////////// - // JSON_DEPTH(json_doc) // // JSONDepth Returns the maximum depth of a JSON document. Returns NULL if the argument is NULL. An error occurs if the @@ -621,12 +614,10 @@ func (j JSONValid) FunctionName() string { return "json_valid" } - ////////////////////////// // JSON table functions // ////////////////////////// - // JSON_TABLE(expr, path COLUMNS (column_list) [AS] alias) // // JSONTable Extracts data from a JSON document and returns it as a relational table having the specified columns. @@ -649,12 +640,10 @@ func (j JSONTable) FunctionName() string { return "json_table" } - /////////////////////////////// // JSON validation functions // /////////////////////////////// - // JSON_SCHEMA_VALID(schema,document) // // JSONSchemaValid Validates a JSON document against a JSON schema. Both schema and document are required. The schema @@ -711,12 +700,10 @@ func (j JSONSchemaValidationReport) FunctionName() string { return "json_schema_validation_report" } - //////////////////////////// // JSON utility functions // //////////////////////////// - // JSON_PRETTY(json_val) // // JSONPretty Provides pretty-printing of JSON values similar to that implemented in PHP and by other languages and diff --git a/sql/json.go b/sql/json.go index 05451d1432..a3de7622dd 100644 --- a/sql/json.go +++ b/sql/json.go @@ -31,7 +31,6 @@ type JsonType interface { type jsonType struct{} - // Compare implements Type interface. func (t jsonType) Compare(a interface{}, b interface{}) (int, error) { var err error diff --git a/sql/json_test.go b/sql/json_test.go index f70b1221af..248a7101b8 100644 --- a/sql/json_test.go +++ b/sql/json_test.go @@ -32,7 +32,7 @@ func TestJsonCompare(t *testing.T) { {`true`, `[0]`, 1}, {`[0]`, `{"a": 0}`, 1}, {`{"a": 0}`, `"a"`, 1}, - { `"a"`, `0`, 1}, + {`"a"`, `0`, 1}, {`0`, `null`, 1}, // null @@ -63,20 +63,20 @@ func TestJsonCompare(t *testing.T) { // objects {`{"a": 0}`, `{"a": 0}`, 0}, // deterministic object ordering with arbitrary rules - {`{"a": 1}`, `{"a": 0}`, 1}, // 1 > 0 - {`{"a": 0}`, `{"a": 0, "b": 1}`, -1}, // longer + {`{"a": 1}`, `{"a": 0}`, 1}, // 1 > 0 + {`{"a": 0}`, `{"a": 0, "b": 1}`, -1}, // longer {`{"a": 0, "c": 2}`, `{"a": 0, "b": 1}`, 1}, // "c" > "b" // nested { left: `{"one": ["x", "y", "z"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, right: `{"one": ["x", "y", "z"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, - cmp: 0, + cmp: 0, }, { left: `{"one": ["x", "y"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, right: `{"one": ["x", "y", "z"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, - cmp: -1, + cmp: -1, }, } diff --git a/sql/json_value.go b/sql/json_value.go index 3fef785b18..eb3f402d85 100644 --- a/sql/json_value.go +++ b/sql/json_value.go @@ -45,7 +45,6 @@ type SearchableJSONValue interface { Search() (path string, err error) } - type JSONDocument struct { Val interface{} } @@ -69,7 +68,6 @@ func (doc JSONDocument) ToString() (string, error) { return string(bb), err } - var _ SearchableJSONValue = JSONDocument{} func (doc JSONDocument) Contains(path string) (ok bool, err error) { @@ -112,7 +110,6 @@ func ConcatenateJSONValues(vals ...JSONValue) (JSONValue, error) { return JSONDocument{Val: arr}, nil } - // JSON values can be compared using the =, <, <=, >, >=, <>, !=, and <=> operators. BETWEEN IN() GREATEST() LEAST() are // not yet supported with JSON values. // @@ -252,8 +249,8 @@ func compareJSONArray(a []interface{}, b interface{}) (int, error) { func compareJSONObject(a map[string]interface{}, b interface{}) (int, error) { switch b := b.(type) { case - bool, - []interface{}: + bool, + []interface{}: // a is lower precedence return -1, nil @@ -284,9 +281,9 @@ func compareJSONObject(a map[string]interface{}, b interface{}) (int, error) { func compareJSONString(a string, b interface{}) (int, error) { switch b := b.(type) { case - bool, - []interface{}, - map[string]interface{}: + bool, + []interface{}, + map[string]interface{}: // a is lower precedence return -1, nil @@ -302,10 +299,10 @@ func compareJSONString(a string, b interface{}) (int, error) { func compareJSONNumber(a float64, b interface{}) (int, error) { switch b := b.(type) { case - bool, - []interface{}, - map[string]interface{}, - string: + bool, + []interface{}, + map[string]interface{}, + string: // a is lower precedence return -1, nil @@ -324,7 +321,6 @@ func compareJSONNumber(a float64, b interface{}) (int, error) { } } - func jsonObjectKeyIntersection(a, b map[string]interface{}) (ks []string) { for key := range a { if _, ok := b[key]; ok { @@ -339,7 +335,7 @@ func jsonObjectDeterministicOrder(a, b map[string]interface{}, inter []string) ( if len(a) > len(b) { return 1, nil } - if len(a) < len(b){ + if len(a) < len(b) { return -1, nil } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 1b2f117ff3..9293f92a0b 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -871,6 +871,8 @@ func convertAlterTable(ctx *sql.Context, ddl *sqlparser.DDL) (sql.Node, error) { case namedConstraint: // For simple named constraint drops, fill in a partial foreign key constraint. This will need to be changed if // we ever support other kinds of constraints than foreign keys (e.g. CHECK) + // TODO: this fails if check constraint delete desired but not indicated + // It works for memory engine right now but won't for Dolt return plan.NewAlterDropForeignKey(table, &sql.ForeignKeyConstraint{ Name: c.name, }), nil diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 405cc0dc14..b3dd5db30d 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -875,6 +875,17 @@ var fixtures = map[string]sql.Node{ Enforced: true, }, ), + `ALTER TABLE t1 ADD CONSTRAINT CHECK (a > 0)`: plan.NewAlterAddCheck( + plan.NewUnresolvedTable("t1", ""), + &sql.CheckConstraint{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + ), `ALTER TABLE t1 DROP FOREIGN KEY fk_name`: plan.NewAlterDropForeignKey( plan.NewUnresolvedTable("t1", ""), &sql.ForeignKeyConstraint{ diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index d1a987d21f..272a6c4abe 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -78,29 +78,6 @@ func getCheckAlterableTable(t sql.Table) (sql.CheckAlterableTable, error) { } } -func GetCheckTable(node sql.Node) (sql.CheckTable, error) { - switch node := node.(type) { - case sql.CheckTable: - return node, nil - case *ResolvedTable: - return getCheckTable(node.Table), nil - default: - return nil, ErrNoCheckConstraintSupport.New(node.String()) - } -} - -// getCheckTable returns the underlying getCheckTable for the table given, or nil if it isn't a getCheckTable -func getCheckTable(t sql.Table) sql.CheckTable { - switch t := t.(type) { - case sql.CheckTable: - return t - case sql.TableWrapper: - return getCheckTable(t.Underlying()) - default: - return nil - } -} - // Expressions implements the sql.Expressioner interface. func (c *CreateCheck) Expressions() []sql.Expression { return []sql.Expression{c.Check.Expr} diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index 4cae8be9df..c58b8f60b5 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -137,7 +137,7 @@ func (c *CreateIndex) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error } for _, e := range exprs { - if sql.IsBlob(e.Type()) || sql.IsJSON(e.Type()){ + if sql.IsBlob(e.Type()) || sql.IsJSON(e.Type()) { return nil, ErrExprTypeNotIndexable.New(e, e.Type()) } } diff --git a/sql/testutils.go b/sql/testutils.go index 41330776ce..7c30e60eac 100644 --- a/sql/testutils.go +++ b/sql/testutils.go @@ -30,4 +30,3 @@ func MustJSON(s string) JSONDocument { } return JSONDocument{Val: doc} } -