diff --git a/engine.go b/engine.go index e675b7f41f..7eeb680ef8 100644 --- a/engine.go +++ b/engine.go @@ -176,7 +176,7 @@ func (e *Engine) QueryWithBindings( case *plan.CreateIndex: typ = sql.CreateIndexProcess perm = auth.ReadPerm | auth.WritePerm - case *plan.CreateForeignKey, *plan.DropForeignKey, *plan.AlterIndex, *plan.CreateView, + case *plan.CreateForeignKey, *plan.CreateCheck, *plan.DropForeignKey, *plan.AlterIndex, *plan.CreateView, *plan.DeleteFrom, *plan.DropIndex, *plan.DropView, *plan.InsertInto, *plan.LockTables, *plan.UnlockTables, *plan.Update: @@ -242,7 +242,7 @@ func ResolveDefaults(tableName string, schema []*ColumnWithRawDefault) (sql.Sche return unresolvedSchema, nil } // *plan.CreateTable properly handles resolving default values, so we hijack it - createTable := plan.NewCreateTable(db, tableName, unresolvedSchema, false, nil, nil) + createTable := plan.NewCreateTable(db, tableName, false, &plan.TableSpec{Schema: unresolvedSchema}) analyzed, err := e.Analyzer.Analyze(ctx, createTable, nil) if err != nil { return nil, err diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 551171a118..6e71190027 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -2187,6 +2187,177 @@ func TestDropForeignKeys(t *testing.T, harness Harness) { assert.True(t, sql.ErrTableNotFound.Is(err)) } +func TestCreateCheckConstraints(t *testing.T, harness Harness) { + require := require.New(t) + + e := NewEngine(t, harness) + + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT CHECK (b > 1)") + + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + ctx := NewContext(harness) + table, ok, err := db.GetTableInsensitive(ctx, "t1") + require.NoError(err) + require.True(ok) + + cht, ok := table.(sql.CheckTable) + require.True(ok) + + checks, err := cht.GetChecks(NewContext(harness)) + require.NoError(err) + + con := []sql.CheckConstraint{ + { + Name: "chk1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.b"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + { + Name: "chk2", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.b"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: false, + }, + { + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.b"), + expression.NewLiteral(int8(1), sql.Int8), + ), + Enforced: true, + }, + } + cmp1, _ := plan.NewCheckDefinition(&con[0]) + cmp2, _ := plan.NewCheckDefinition(&con[1]) + cmp3, _ := plan.NewCheckDefinition(&con[2]) + expected := []sql.CheckDefinition{*cmp1, *cmp2, *cmp3} + + assert.Equal(t, expected, checks) + + // Some faulty create statements + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t2 ADD CONSTRAINT chk2 CHECK (c > 0)") + require.Error(err) + assert.True(t, sql.ErrTableNotFound.Is(err)) + + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (c > 0)") + require.Error(err) + assert.True(t, sql.ErrTableColumnNotFound.Is(err)) +} + +func TestChecksOnInsert(t *testing.T, harness Harness) { + + e := NewEngine(t, harness) + + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (b > 1) NOT ENFORCED") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (b > -1) ENFORCED") + RunQuery(t, e, harness, "INSERT INTO t1 VALUES (1,1)") + TestQuery(t, harness, e, `SELECT * FROM t1`, + []sql.Row{ + {1, 1}, + }, + nil, + ) + RunQuery(t, e, harness, "INSERT INTO t1 VALUES (0,0)") + TestQuery(t, harness, e, `SELECT * FROM t1`, + []sql.Row{ + {1, 1}, + }, + nil, + ) + + ctx := NewContext(harness) + require.True(t, len(ctx.Warnings()) > 0) + + expectedCode := 3819 + condition := false + for _, warning := range ctx.Warnings() { + if warning.Code == expectedCode { + condition = true + break + } + } + + require.True(t, condition) + +} + +func TestDisallowedCheckConstraints(t *testing.T, harness Harness) { + require := require.New(t) + e := NewEngine(t, harness) + var err error + + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER)") + + // functions, UDFs, procedures + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (current_user = \"root@\")") + require.Error(err) + assert.True(t, sql.ErrInvalidConstraintFunctionsNotSupported.Is(err)) + + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK ((select count(*) from t1) = 0)") + require.Error(err) + assert.True(t, sql.ErrInvalidConstraintSubqueryNotSupported.Is(err)) +} + +func TestDropCheckConstraints(t *testing.T, harness Harness) { + require := require.New(t) + + e := NewEngine(t, harness) + + RunQuery(t, e, harness, "CREATE TABLE t1 (a INTEGER PRIMARY KEY, b INTEGER, c integer)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk1 CHECK (a > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk2 CHECK (b > 0) NOT ENFORCED") + RunQuery(t, e, harness, "ALTER TABLE t1 ADD CONSTRAINT chk3 CHECK (c > 0)") + RunQuery(t, e, harness, "ALTER TABLE t1 DROP CONSTRAINT chk2") + RunQuery(t, e, harness, "ALTER TABLE t1 DROP CHECK chk1") + + db, err := e.Catalog.Database("mydb") + require.NoError(err) + + ctx := NewContext(harness) + table, ok, err := db.GetTableInsensitive(ctx, "t1") + require.NoError(err) + require.True(ok) + + cht, ok := table.(sql.CheckTable) + require.True(ok) + + checks, err := cht.GetChecks(NewContext(harness)) + require.NoError(err) + + con := sql.CheckConstraint{ + Name: "chk3", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("t1.c"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + } + cmp, _ := plan.NewCheckDefinition(&con) + expected := []sql.CheckDefinition{*cmp} + + assert.Equal(t, expected, checks) + + // Some faulty create statements + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t2 DROP CONSTRAINT chk2") + require.Error(err) + assert.True(t, sql.ErrTableNotFound.Is(err)) + + _, _, err = e.Query(NewContext(harness), "ALTER TABLE t1 DROP CHECK chk3") + require.NoError(err) +} + func TestNaturalJoin(t *testing.T, harness Harness) { require := require.New(t) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 2195cc0acc..fa9146e27a 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -332,6 +332,22 @@ func TestDropForeignKeys(t *testing.T) { enginetest.TestDropForeignKeys(t, enginetest.NewDefaultMemoryHarness()) } +func TestCreateCheckConstraints(t *testing.T) { + enginetest.TestCreateCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) +} + +func TestChecksOnInsert(t *testing.T) { + enginetest.TestChecksOnInsert(t, enginetest.NewDefaultMemoryHarness()) +} + +func TestTestDisallowedCheckConstraints(t *testing.T) { + enginetest.TestDisallowedCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) +} + +func TestDropCheckConstraints(t *testing.T) { + enginetest.TestDropCheckConstraints(t, enginetest.NewDefaultMemoryHarness()) +} + func TestExplode(t *testing.T) { enginetest.TestExplode(t, enginetest.NewDefaultMemoryHarness()) } diff --git a/go.mod b/go.mod index 7fac09dc7c..1f01113e9b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ require ( github.com/VividCortex/gohistogram v1.0.0 // indirect github.com/cespare/xxhash v1.1.0 github.com/dolthub/sqllogictest/go v0.0.0-20201105013724-5123fc66e12c - github.com/dolthub/vitess v0.0.0-20210316150645-9bf20be78424 + github.com/dolthub/vitess v0.0.0-20210317175908-b2663b2c4d9c github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect github.com/go-kit/kit v0.9.0 github.com/go-sql-driver/mysql v1.4.1 @@ -25,6 +25,7 @@ require ( github.com/src-d/go-oniguruma v1.1.0 github.com/stretchr/testify v1.4.0 github.com/tebeka/strftime v0.1.4 // indirect + google.golang.org/genproto v0.0.0-20190926190326-7ee9db18f195 google.golang.org/grpc v1.27.0 // indirect gopkg.in/src-d/go-errors.v1 v1.0.0 ) diff --git a/go.sum b/go.sum index 3b585820b2..07c028ddb2 100755 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201105013724-5123fc66e12c h1:ZIo6IOX github.com/dolthub/sqllogictest/go v0.0.0-20201105013724-5123fc66e12c/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= github.com/dolthub/vitess v0.0.0-20210316150645-9bf20be78424 h1:QGVknlmy9WZoTOTyC50W+HCho7t3/CrpMxJZm+tQ25A= github.com/dolthub/vitess v0.0.0-20210316150645-9bf20be78424/go.mod h1:hUE8oSk2H5JZnvtlLBhJPYC8WZCA5AoSntdLTcBvdBM= +github.com/dolthub/vitess v0.0.0-20210317175908-b2663b2c4d9c h1:u/hKK5EBh+FJlZlcMjSYLB9qXTG+pnyaa6Vww8sSHvY= +github.com/dolthub/vitess v0.0.0-20210317175908-b2663b2c4d9c/go.mod h1:hUE8oSk2H5JZnvtlLBhJPYC8WZCA5AoSntdLTcBvdBM= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 h1:Ghm4eQYC0nEPnSJdVkTrXpu9KtoVCSo1hg7mtI7G9KU= diff --git a/memory/table.go b/memory/table.go index 79152aa31d..4fe63fa1f0 100644 --- a/memory/table.go +++ b/memory/table.go @@ -37,6 +37,7 @@ type Table struct { columns []int indexes map[string]sql.Index foreignKeys []sql.ForeignKeyConstraint + checks []sql.CheckDefinition pkIndexesEnabled bool // Data storage @@ -66,6 +67,8 @@ var _ sql.IndexAlterableTable = (*Table)(nil) var _ sql.IndexedTable = (*Table)(nil) var _ sql.ForeignKeyAlterableTable = (*Table)(nil) var _ sql.ForeignKeyTable = (*Table)(nil) +var _ sql.CheckAlterableTable = (*Table)(nil) +var _ sql.CheckTable = (*Table)(nil) var _ sql.AutoIncrementTable = (*Table)(nil) var _ sql.StatisticsTable = (*Table)(nil) @@ -1103,6 +1106,12 @@ func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string } } + for _, key := range t.checks { + if key.Name == fkName { + return fmt.Errorf("constraint %s already exists", fkName) + } + } + t.foreignKeys = append(t.foreignKeys, sql.ForeignKeyConstraint{ Name: fkName, Columns: columns, @@ -1117,15 +1126,54 @@ func (t *Table) CreateForeignKey(_ *sql.Context, fkName string, columns []string // DropForeignKey implements sql.ForeignKeyAlterableTable. func (t *Table) DropForeignKey(ctx *sql.Context, fkName string) error { + return t.dropConstraint(ctx, fkName) +} + +func (t *Table) dropConstraint(ctx *sql.Context, name string) error { for i, key := range t.foreignKeys { - if key.Name == fkName { + if key.Name == name { t.foreignKeys = append(t.foreignKeys[:i], t.foreignKeys[i+1:]...) return nil } } + for i, key := range t.checks { + if key.Name == name { + t.checks = append(t.checks[:i], t.checks[i+1:]...) + return nil + } + } return nil } +// GetChecks implements sql.CheckTable +func (t *Table) GetChecks(_ *sql.Context) ([]sql.CheckDefinition, error) { + return t.checks, nil +} + +// CreateCheck implements sql.CheckAlterableTable +func (t *Table) CreateCheck(_ *sql.Context, check *sql.CheckDefinition) error { + for _, key := range t.checks { + if key.Name == check.Name { + return fmt.Errorf("constraint %s already exists", check.Name) + } + } + + for _, key := range t.foreignKeys { + if key.Name == check.Name { + return fmt.Errorf("constraint %s already exists", check.Name) + } + } + + t.checks = append(t.checks, *check) + + return nil +} + +// func (t *Table) DropCheck(ctx *sql.Context, chName string) error {} implements sql.CheckAlterableTable. +func (t *Table) DropCheck(ctx *sql.Context, chName string) error { + return t.dropConstraint(ctx, chName) +} + func (t *Table) createIndex(name string, columns []sql.IndexColumn, constraint sql.IndexConstraint, comment string) (sql.Index, error) { if t.indexes[name] != nil { // TODO: extract a standard error type for this diff --git a/sql/analyzer/check_constraints.go b/sql/analyzer/check_constraints.go new file mode 100644 index 0000000000..f6440a4a07 --- /dev/null +++ b/sql/analyzer/check_constraints.go @@ -0,0 +1,172 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package analyzer + +import ( + "strings" + + "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/parse" + "github.com/dolthub/go-mysql-server/sql/plan" +) + +// validateCreateCheck handles CreateCheck nodes, resolving references to "old" and "new" table references in +// the check body. Also validates that these old and new references are being used appropriately -- they are only +// valid for certain kinds of checks and certain statements. +func validateCreateCheck(ctx *sql.Context, a *Analyzer, node sql.Node, scope *Scope) (sql.Node, error) { + ct, ok := node.(*plan.CreateCheck) + if !ok { + return node, nil + } + + chAlterable, ok := ct.UnaryNode.Child.(sql.Table) + if !ok { + return node, nil + } + + checkCols := make(map[string]bool) + for _, col := range chAlterable.Schema() { + checkCols[col.Name] = true + } + + var err error + plan.InspectExpressionsWithNode(node, func(n sql.Node, e sql.Expression) bool { + if _, ok := n.(*plan.CreateCheck); !ok { + return true + } + + // Make sure that all columns are valid, in the table, and there are no duplicates + switch expr := e.(type) { + case *deferredColumn: + if _, ok := checkCols[expr.Name()]; !ok { + err = sql.ErrTableColumnNotFound.New(expr.Name()) + return false + } + case *expression.GetField: + if _, ok := checkCols[expr.Name()]; !ok { + err = sql.ErrTableColumnNotFound.New(expr.Name()) + return false + } + case *expression.UnresolvedFunction: + err = sql.ErrInvalidConstraintFunctionsNotSupported.New(expr.String()) + return false + case *plan.Subquery: + err = sql.ErrInvalidConstraintSubqueryNotSupported.New(expr.String()) + return false + } + return true + }) + + if err != nil { + return nil, err + } + + return ct, nil +} + +// loadChecks loads any checks that are required for a plan node to operate properly (except for nodes dealing with +// check execution). +func loadChecks(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) { + span, _ := ctx.Span("loadChecks") + defer span.Finish() + + return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) { + switch node := n.(type) { + case *plan.InsertInto: + nc := *node + + rtable, ok := nc.Destination.(*plan.ResolvedTable) + if !ok { + return node, nil + } + + table, ok := rtable.Table.(sql.CheckAlterableTable) + if !ok { + return node, nil + } + + loadedChecks, err := loadChecksFromTable(ctx, table) + if err != nil { + return nil, err + } + + if len(loadedChecks) != 0 { + nc.Checks = loadedChecks + } else { + nc.Checks = make([]sql.Expression, 0) + } + + return &nc, nil + // TODO : reimplement modify column nodes and throw errors here to protect check columns + //case *plan.DropColumn: + //case *plan.RenameColumn: + //case *plan.ModifyColumn: + default: + return node, nil + } + }) +} + +func loadChecksFromTable(ctx *sql.Context, table sql.Table) ([]sql.Expression, error) { + var loadedChecks []sql.Expression + if checkTable, ok := table.(sql.CheckTable); ok { + checks, err := checkTable.GetChecks(ctx) + if err != nil { + return nil, err + } + for _, ch := range checks { + constraint, err := convertCheckDefToConstraint(ctx, &ch) + if err != nil { + return nil, err + } + if constraint.Enforced { + loadedChecks = append(loadedChecks, constraint.Expr) + } + } + } + return loadedChecks, nil +} + +func convertCheckDefToConstraint(ctx *sql.Context, check *sql.CheckDefinition) (*sql.CheckConstraint, error) { + parsed, err := sqlparser.ParseStrictDDL(check.AlterStatement) + if err != nil { + return nil, err + } + + ddl, ok := parsed.(*sqlparser.DDL) + if !ok || ddl.ConstraintAction == "" || len(ddl.TableSpec.Constraints) != 1 || strings.ToLower(ddl.ConstraintAction) != sqlparser.AddStr { + return nil, parse.ErrInvalidCheckConstraint.New(check.AlterStatement) + } + + parsedConstraint := ddl.TableSpec.Constraints[0] + chConstraint, ok := parsedConstraint.Details.(*sqlparser.CheckConstraintDefinition) + if !ok || chConstraint.Expr == nil { + return nil, parse.ErrInvalidCheckConstraint.New(check.AlterStatement) + } + + c, err := parse.ExprToExpression(ctx, chConstraint.Expr) + if err != nil { + return nil, err + } + + return &sql.CheckConstraint{ + Name: parsedConstraint.Name, + Expr: c, + Enforced: chConstraint.Enforced, + }, nil +} diff --git a/sql/analyzer/join_search.go b/sql/analyzer/join_search.go index af36c73d20..a512c07cde 100755 --- a/sql/analyzer/join_search.go +++ b/sql/analyzer/join_search.go @@ -207,7 +207,8 @@ func (jo *joinOrderNode) applyJoinHintTables(tables []string) ([]string, error) assigned := make(map[int]struct{}) order := []int{} remaining := tables -START: for { +START: + for { var i int for i = range jo.commutes { if _, ok := assigned[i]; ok { @@ -368,7 +369,7 @@ func (jo *joinOrderNode) visitJoinSearchNodes(cb func(n *joinSearchNode) bool) { } else if jo.left != nil { stop := false jo.left.visitJoinSearchNodes(func(l *joinSearchNode) bool { - jo.right.visitJoinSearchNodes(func (r *joinSearchNode) bool { + jo.right.visitJoinSearchNodes(func(r *joinSearchNode) bool { if !cb(&joinSearchNode{left: l, right: r}) { stop = true } @@ -439,7 +440,6 @@ func newJoinOrderNode(node sql.Node) *joinOrderNode { } } - // A joinSearchNode is a simplified type representing a join tree node, which is either an internal node (a join) or a // leaf node (a table). The top level node in a join tree is always an internal node. Every internal node has both a // left and a right child. diff --git a/sql/analyzer/join_search_test.go b/sql/analyzer/join_search_test.go index 78b2803209..ac18e99af2 100755 --- a/sql/analyzer/join_search_test.go +++ b/sql/analyzer/join_search_test.go @@ -532,9 +532,9 @@ func TestBuildJoinTree(t *testing.T) { }, }, { - name: "explicit subtree, A((EB)(DC))", + name: "explicit subtree, A((EB)(DC))", tableOrder: tableOrder("A", &joinOrderNode{ - left: tableOrder("E", "B"), + left: tableOrder("E", "B"), right: tableOrder("D", "C"), }), joinConds: []*joinCond{ @@ -562,9 +562,9 @@ func TestBuildJoinTree(t *testing.T) { }, }, { - name: "explicit subtree, A((EB)(D))C", + name: "explicit subtree, A((EB)(D))C", tableOrder: tableOrder("A", &joinOrderNode{ - left: tableOrder("E", "B"), + left: tableOrder("E", "B"), right: tableOrder("D"), }, "C"), joinConds: []*joinCond{ diff --git a/sql/analyzer/resolve_create_like.go b/sql/analyzer/resolve_create_like.go index 27418fe921..3fe186a3de 100644 --- a/sql/analyzer/resolve_create_like.go +++ b/sql/analyzer/resolve_create_like.go @@ -68,5 +68,11 @@ func resolveCreateLike(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) tempCol.Source = planCreate.Name() newSch[i] = &tempCol } - return plan.NewCreateTable(planCreate.Database(), planCreate.Name(), newSch, planCreate.IfNotExists(), idxDefs, nil), nil + + tableSpec := &plan.TableSpec{ + Schema: newSch, + IdxDefs: idxDefs, + } + + return plan.NewCreateTable(planCreate.Database(), planCreate.Name(), planCreate.IfNotExists(), tableSpec), nil } diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 040f8f39ee..9180f4e53e 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -25,6 +25,7 @@ var OnceBeforeDefault = []Rule{ {"resolve_views", resolveViews}, {"resolve_common_table_expressions", resolveCommonTableExpressions}, {"resolve_tables", resolveTables}, + {"load_check_constraints", loadChecks}, {"resolve_set_variables", resolveSetVariables}, {"resolve_create_like", resolveCreateLike}, {"resolve_subqueries", resolveSubqueries}, @@ -44,6 +45,7 @@ var DefaultRules = []Rule{ {"pushdown_groupby_aliases", pushdownGroupByAliases}, {"qualify_columns", qualifyColumns}, {"resolve_columns", resolveColumns}, + {"validate_check_constraint", validateCreateCheck}, {"resolve_bareword_set_variables", resolveBarewordSetVariables}, {"resolve_database", resolveDatabase}, {"expand_stars", expandStars}, diff --git a/sql/core.go b/sql/core.go index fd20d3168f..2067cfdf58 100644 --- a/sql/core.go +++ b/sql/core.go @@ -225,6 +225,20 @@ type ForeignKeyConstraint struct { OnDelete ForeignKeyReferenceOption } +// CheckDefinition defines a trigger. Integrators are not expected to parse or understand the trigger definitions, +// but must store and return them when asked. +type CheckDefinition struct { + Name string // The name of this check. Check names in a database are unique. + AlterStatement string // Reference for expression body +} + +// CheckConstraint declares a boolean-eval constraint. +type CheckConstraint struct { + Name string + Expr Expression + Enforced bool +} + // TableWrapper is a node that wraps the real table. This is needed because // wrappers cannot implement some methods the table may implement. type TableWrapper interface { @@ -351,6 +365,23 @@ type ForeignKeyAlterableTable interface { DropForeignKey(ctx *Context, fkName string) error } +// CheckTable is a table that can declare its check constraints. +type CheckTable interface { + Table + // GetChecks returns the check constraints on this table. + GetChecks(ctx *Context) ([]CheckDefinition, error) +} + +// CheckAlterableTable represents a table that supports check constraints. +type CheckAlterableTable interface { + Table + // CreateCheck creates an check constraint for this table, using the provided parameters. + // Returns an error if the constraint name already exists. + CreateCheck(ctx *Context, check *CheckDefinition) error + // DropCheck removes a check constraint from the database. + DropCheck(ctx *Context, chName string) error +} + // InsertableTable is a table that can process insertion of new rows. type InsertableTable interface { Table diff --git a/sql/errors.go b/sql/errors.go index ad32705ec3..e28eed909f 100755 --- a/sql/errors.go +++ b/sql/errors.go @@ -169,6 +169,12 @@ var ( // ErrCannotDropDatabaseDoesntExist is returned when a DROP DATABASE is callend when a table is dropped that doesn't exist. ErrCannotDropDatabaseDoesntExist = errors.NewKind("can't drop database %s; database doesn't exist") + // ErrInvalidConstraintFunctionsNotSupported is returned when a CONSTRAINT CHECK is called with a sub-function expression. + ErrInvalidConstraintFunctionsNotSupported = errors.NewKind("Invalid constraint expression, functions not supported: %s") + + // ErrInvalidConstraintSubqueryNotSupported is returned when a CONSTRAINT CHECK is called with a sub-query expression. + ErrInvalidConstraintSubqueryNotSupported = errors.NewKind("Invalid constraint expression, sub-queries not supported: %s") + // ErrColumnCountMismatch is returned when a view, derived table or common table expression has a declared column // list with a different number of columns than the schema of the table. ErrColumnCountMismatch = errors.NewKind("In definition of view, derived table or common table expression, SELECT list and column names list have different column counts") diff --git a/sql/expression/function/aggregation/json_agg_test.go b/sql/expression/function/aggregation/json_agg_test.go index b91b957801..450f07e74c 100644 --- a/sql/expression/function/aggregation/json_agg_test.go +++ b/sql/expression/function/aggregation/json_agg_test.go @@ -69,7 +69,7 @@ func TestJsonArrayAgg_Empty(t *testing.T) { v, err := j.Eval(ctx, b) assert.NoError(err) - assert.Equal(sql.JSONDocument{Val:[]interface{}(nil)}, v) + assert.Equal(sql.JSONDocument{Val: []interface{}(nil)}, v) } func TestJsonArrayAgg_JSON(t *testing.T) { diff --git a/sql/expression/function/json_unsupported.go b/sql/expression/function/json_unsupported.go index d4f3c1456b..f7b393434c 100644 --- a/sql/expression/function/json_unsupported.go +++ b/sql/expression/function/json_unsupported.go @@ -27,7 +27,6 @@ var ErrUnsupportedJSONFunction = errors.NewKind("unsupported JSON function: %s") // JSON search functions // /////////////////////////// - // JSON_CONTAINS_PATH(json_doc, one_or_all, path[, path] ...) // // JSONContainsPath Returns 0 or 1 to indicate whether a JSON document contains data at a given path or paths. Returns @@ -185,12 +184,10 @@ func (j JSONValue) FunctionName() string { // https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of // TODO(andy): relocate - ///////////////////////////// // JSON creation functions // ///////////////////////////// - // JSON_ARRAY([val[, val] ...]) // // JSONArray Evaluates a (possibly empty) list of values and returns a JSON array containing those values. @@ -260,12 +257,10 @@ func (j JSONQuote) FunctionName() string { return "json_quote" } - ///////////////////////////////// // JSON modification functions // ///////////////////////////////// - // JSON_ARRAY_APPEND(json_doc, path, val[, path, val] ...) // // JSONArrayAppend Appends values to the end of the indicated arrays within a JSON document and returns the result. @@ -520,12 +515,10 @@ func (j JSONSet) FunctionName() string { return "json_set" } - ////////////////////////////// // JSON attribute functions // ////////////////////////////// - // JSON_DEPTH(json_doc) // // JSONDepth Returns the maximum depth of a JSON document. Returns NULL if the argument is NULL. An error occurs if the @@ -621,12 +614,10 @@ func (j JSONValid) FunctionName() string { return "json_valid" } - ////////////////////////// // JSON table functions // ////////////////////////// - // JSON_TABLE(expr, path COLUMNS (column_list) [AS] alias) // // JSONTable Extracts data from a JSON document and returns it as a relational table having the specified columns. @@ -649,12 +640,10 @@ func (j JSONTable) FunctionName() string { return "json_table" } - /////////////////////////////// // JSON validation functions // /////////////////////////////// - // JSON_SCHEMA_VALID(schema,document) // // JSONSchemaValid Validates a JSON document against a JSON schema. Both schema and document are required. The schema @@ -711,12 +700,10 @@ func (j JSONSchemaValidationReport) FunctionName() string { return "json_schema_validation_report" } - //////////////////////////// // JSON utility functions // //////////////////////////// - // JSON_PRETTY(json_val) // // JSONPretty Provides pretty-printing of JSON values similar to that implemented in PHP and by other languages and diff --git a/sql/json.go b/sql/json.go index 05451d1432..a3de7622dd 100644 --- a/sql/json.go +++ b/sql/json.go @@ -31,7 +31,6 @@ type JsonType interface { type jsonType struct{} - // Compare implements Type interface. func (t jsonType) Compare(a interface{}, b interface{}) (int, error) { var err error diff --git a/sql/json_test.go b/sql/json_test.go index f70b1221af..248a7101b8 100644 --- a/sql/json_test.go +++ b/sql/json_test.go @@ -32,7 +32,7 @@ func TestJsonCompare(t *testing.T) { {`true`, `[0]`, 1}, {`[0]`, `{"a": 0}`, 1}, {`{"a": 0}`, `"a"`, 1}, - { `"a"`, `0`, 1}, + {`"a"`, `0`, 1}, {`0`, `null`, 1}, // null @@ -63,20 +63,20 @@ func TestJsonCompare(t *testing.T) { // objects {`{"a": 0}`, `{"a": 0}`, 0}, // deterministic object ordering with arbitrary rules - {`{"a": 1}`, `{"a": 0}`, 1}, // 1 > 0 - {`{"a": 0}`, `{"a": 0, "b": 1}`, -1}, // longer + {`{"a": 1}`, `{"a": 0}`, 1}, // 1 > 0 + {`{"a": 0}`, `{"a": 0, "b": 1}`, -1}, // longer {`{"a": 0, "c": 2}`, `{"a": 0, "b": 1}`, 1}, // "c" > "b" // nested { left: `{"one": ["x", "y", "z"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, right: `{"one": ["x", "y", "z"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, - cmp: 0, + cmp: 0, }, { left: `{"one": ["x", "y"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, right: `{"one": ["x", "y", "z"], "two": { "a": 0, "b": 1}, "three": false, "four": null, "five": " "}`, - cmp: -1, + cmp: -1, }, } diff --git a/sql/json_value.go b/sql/json_value.go index 3fef785b18..eb3f402d85 100644 --- a/sql/json_value.go +++ b/sql/json_value.go @@ -45,7 +45,6 @@ type SearchableJSONValue interface { Search() (path string, err error) } - type JSONDocument struct { Val interface{} } @@ -69,7 +68,6 @@ func (doc JSONDocument) ToString() (string, error) { return string(bb), err } - var _ SearchableJSONValue = JSONDocument{} func (doc JSONDocument) Contains(path string) (ok bool, err error) { @@ -112,7 +110,6 @@ func ConcatenateJSONValues(vals ...JSONValue) (JSONValue, error) { return JSONDocument{Val: arr}, nil } - // JSON values can be compared using the =, <, <=, >, >=, <>, !=, and <=> operators. BETWEEN IN() GREATEST() LEAST() are // not yet supported with JSON values. // @@ -252,8 +249,8 @@ func compareJSONArray(a []interface{}, b interface{}) (int, error) { func compareJSONObject(a map[string]interface{}, b interface{}) (int, error) { switch b := b.(type) { case - bool, - []interface{}: + bool, + []interface{}: // a is lower precedence return -1, nil @@ -284,9 +281,9 @@ func compareJSONObject(a map[string]interface{}, b interface{}) (int, error) { func compareJSONString(a string, b interface{}) (int, error) { switch b := b.(type) { case - bool, - []interface{}, - map[string]interface{}: + bool, + []interface{}, + map[string]interface{}: // a is lower precedence return -1, nil @@ -302,10 +299,10 @@ func compareJSONString(a string, b interface{}) (int, error) { func compareJSONNumber(a float64, b interface{}) (int, error) { switch b := b.(type) { case - bool, - []interface{}, - map[string]interface{}, - string: + bool, + []interface{}, + map[string]interface{}, + string: // a is lower precedence return -1, nil @@ -324,7 +321,6 @@ func compareJSONNumber(a float64, b interface{}) (int, error) { } } - func jsonObjectKeyIntersection(a, b map[string]interface{}) (ks []string) { for key := range a { if _, ok := b[key]; ok { @@ -339,7 +335,7 @@ func jsonObjectDeterministicOrder(a, b map[string]interface{}, inter []string) ( if len(a) > len(b) { return 1, nil } - if len(a) < len(b){ + if len(a) < len(b) { return -1, nil } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index 993c1d0f02..9293f92a0b 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -56,6 +56,8 @@ var ( ErrInvalidAutoIncCols = errors.NewKind("there can be only one auto_increment column and it must be defined as a key") ErrUnknownConstraintDefinition = errors.NewKind("unknown constraint definition: %s, %T") + + ErrInvalidCheckConstraint = errors.NewKind("invalid constraint definition: %s") ) var ( @@ -242,7 +244,7 @@ func convertIfConditional(ctx *sql.Context, n sqlparser.IfStatementCondition) (* if err != nil { return nil, err } - condition, err := exprToExpression(ctx, n.Expr) + condition, err := ExprToExpression(ctx, n.Expr) if err != nil { return nil, err } @@ -360,7 +362,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.ShowTablesOpt.Filter != nil { if s.ShowTablesOpt.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + filter, err = ExprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -385,7 +387,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.Filter != nil { if s.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.Filter.Filter) + filter, err = ExprToExpression(ctx, s.Filter.Filter) if err != nil { return nil, err } @@ -417,7 +419,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.ShowTablesOpt.Filter != nil { if s.ShowTablesOpt.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + filter, err = ExprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -431,7 +433,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.ShowTablesOpt.AsOf != nil { var err error - asOf, err = exprToExpression(ctx, s.ShowTablesOpt.AsOf) + asOf, err = ExprToExpression(ctx, s.ShowTablesOpt.AsOf) if err != nil { return nil, err } @@ -467,7 +469,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e } if s.ShowTablesOpt.Filter.Filter != nil { - filter, err := exprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) + filter, err := ExprToExpression(ctx, s.ShowTablesOpt.Filter.Filter) if err != nil { return nil, err } @@ -490,7 +492,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e } if s.ShowCollationFilterOpt != nil { - filterExpr, err := exprToExpression(ctx, *s.ShowCollationFilterOpt) + filterExpr, err := ExprToExpression(ctx, *s.ShowCollationFilterOpt) if err != nil { return nil, err } @@ -504,7 +506,7 @@ func convertShow(ctx *sql.Context, s *sqlparser.Show, query string) (sql.Node, e if s.Filter != nil { if s.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.Filter.Filter) + filter, err = ExprToExpression(ctx, s.Filter.Filter) if err != nil { return nil, err } @@ -809,7 +811,7 @@ func convertCreateProcedure(ctx *sql.Context, query string, c *sqlparser.DDL) (s func convertCall(ctx *sql.Context, c *sqlparser.Call) (sql.Node, error) { params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { - expr, err := exprToExpression(ctx, param) + expr, err := ExprToExpression(ctx, param) if err != nil { return nil, err } @@ -847,21 +849,30 @@ func convertAlterTable(ctx *sql.Context, ddl *sqlparser.DDL) (sql.Node, error) { } switch strings.ToLower(ddl.ConstraintAction) { case sqlparser.AddStr: - if fkConstraint, ok := parsedConstraint.(*sql.ForeignKeyConstraint); ok { + switch c := parsedConstraint.(type) { + case *sql.ForeignKeyConstraint: return plan.NewAlterAddForeignKey( table, - plan.NewUnresolvedTable(fkConstraint.ReferencedTable, ddl.Table.Qualifier.String()), - fkConstraint), nil - } else { + plan.NewUnresolvedTable(c.ReferencedTable, ddl.Table.Qualifier.String()), + c), nil + + case *sql.CheckConstraint: + return plan.NewAlterAddCheck(table, c), nil + default: return nil, ErrUnsupportedFeature.New(sqlparser.String(ddl)) + } case sqlparser.DropStr: switch c := parsedConstraint.(type) { case *sql.ForeignKeyConstraint: return plan.NewAlterDropForeignKey(table, c), nil + case *sql.CheckConstraint: + return plan.NewAlterDropCheck(table, c), nil case namedConstraint: // For simple named constraint drops, fill in a partial foreign key constraint. This will need to be changed if // we ever support other kinds of constraints than foreign keys (e.g. CHECK) + // TODO: this fails if check constraint delete desired but not indicated + // It works for memory engine right now but won't for Dolt return plan.NewAlterDropForeignKey(table, &sql.ForeignKeyConstraint{ Name: c.name, }), nil @@ -1052,6 +1063,7 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { } var fkDefs []*sql.ForeignKeyConstraint + var chDefs []*sql.CheckConstraint for _, unknownConstraint := range c.TableSpec.Constraints { parsedConstraint, err := convertConstraintDefinition(ctx, unknownConstraint) if err != nil { @@ -1060,6 +1072,8 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { switch constraint := parsedConstraint.(type) { case *sql.ForeignKeyConstraint: fkDefs = append(fkDefs, constraint) + case *sql.CheckConstraint: + chDefs = append(chDefs, constraint) default: return nil, ErrUnknownConstraintDefinition.New(unknownConstraint.Name, unknownConstraint) } @@ -1128,8 +1142,15 @@ func convertCreateTable(ctx *sql.Context, c *sqlparser.DDL) (sql.Node, error) { } } + tableSpec := &plan.TableSpec{ + Schema: schema, + IdxDefs: idxDefs, + FkDefs: fkDefs, + ChDefs: chDefs, + } + return plan.NewCreateTable( - sql.UnresolvedDatabase(""), c.Table.Name.String(), schema, c.IfNotExists, idxDefs, fkDefs), nil + sql.UnresolvedDatabase(""), c.Table.Name.String(), c.IfNotExists, tableSpec), nil } type namedConstraint struct { @@ -1154,6 +1175,21 @@ func convertConstraintDefinition(ctx *sql.Context, cd *sqlparser.ConstraintDefin OnUpdate: convertReferenceAction(fkConstraint.OnUpdate), OnDelete: convertReferenceAction(fkConstraint.OnDelete), }, nil + } else if chConstraint, ok := cd.Details.(*sqlparser.CheckConstraintDefinition); ok { + var c sql.Expression + var err error + if chConstraint.Expr != nil { + c, err = ExprToExpression(ctx, chConstraint.Expr) + if err != nil { + return nil, err + } + } + + return &sql.CheckConstraint{ + Name: cd.Name, + Expr: c, + Enforced: chConstraint.Enforced, + }, nil } else if len(cd.Name) > 0 && cd.Details == nil { return namedConstraint{cd.Name}, nil } @@ -1226,6 +1262,7 @@ func convertInsert(ctx *sql.Context, i *sqlparser.Insert) (sql.Node, error) { isReplace, columnsToStrings(i.Columns), onDupExprs, + nil, ), nil } @@ -1331,6 +1368,7 @@ func convertLoad(ctx *sql.Context, d *sqlparser.Load) (sql.Node, error) { false, ld.ColumnNames, nil, + nil, ), nil } @@ -1432,7 +1470,7 @@ func columnDefinitionToColumn(ctx *sql.Context, cd *sqlparser.ColumnDefinition, var defaultVal *sql.ColumnDefaultValue if cd.Type.Default != nil { - parsedExpr, err := exprToExpression(ctx, cd.Type.Default) + parsedExpr, err := ExprToExpression(ctx, cd.Type.Default) if err != nil { return nil, err } @@ -1489,7 +1527,7 @@ func valuesToValues(ctx *sql.Context, v sqlparser.Values) (sql.Node, error) { exprs := make([]sql.Expression, len(vt)) exprTuples[i] = exprs for j, e := range vt { - expr, err := exprToExpression(ctx, e) + expr, err := ExprToExpression(ctx, e) if err != nil { return nil, err } @@ -1544,7 +1582,7 @@ func tableExprToTable( case sqlparser.TableName: var node *plan.UnresolvedTable if t.AsOf != nil { - asOfExpr, err := exprToExpression(ctx, t.AsOf.Time) + asOfExpr, err := ExprToExpression(ctx, t.AsOf.Time) if err != nil { return nil, err } @@ -1597,7 +1635,7 @@ func tableExprToTable( return plan.NewCrossJoin(left, right), nil } - cond, err := exprToExpression(ctx, t.Condition.On) + cond, err := ExprToExpression(ctx, t.Condition.On) if err != nil { return nil, err } @@ -1616,7 +1654,7 @@ func tableExprToTable( } func whereToFilter(ctx *sql.Context, w *sqlparser.Where, child sql.Node) (*plan.Filter, error) { - c, err := exprToExpression(ctx, w.Expr) + c, err := ExprToExpression(ctx, w.Expr) if err != nil { return nil, err } @@ -1636,7 +1674,7 @@ func orderByToSort(ctx *sql.Context, ob sqlparser.OrderBy, child sql.Node) (*pla func orderByToSortFields(ctx *sql.Context, ob sqlparser.OrderBy) ([]sql.SortField, error) { var sortFields []sql.SortField for _, o := range ob { - e, err := exprToExpression(ctx, o.Expr) + e, err := ExprToExpression(ctx, o.Expr) if err != nil { return nil, err } @@ -1675,7 +1713,7 @@ func limitToLimit( } func havingToHaving(ctx *sql.Context, having *sqlparser.Where, node sql.Node) (sql.Node, error) { - cond, err := exprToExpression(ctx, having.Expr) + cond, err := ExprToExpression(ctx, having.Expr) if err != nil { return nil, err } @@ -1703,7 +1741,7 @@ func offsetToOffset( // getInt64Literal returns an int64 *expression.Literal for the value given, or an unsupported error with the string // given if the expression doesn't represent an integer literal. func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*expression.Literal, error) { - e, err := exprToExpression(ctx, expr) + e, err := ExprToExpression(ctx, expr) if err != nil { return nil, err } @@ -1873,7 +1911,7 @@ func StringToColumnDefaultValue(ctx *sql.Context, exprStr string) (*sql.ColumnDe if !ok { return nil, fmt.Errorf("DefaultStringToExpression expected *sqlparser.AliasedExpr but received %T", parserSelect.SelectExprs[0]) } - parsedExpr, err := exprToExpression(ctx, aliasedExpr.Expr) + parsedExpr, err := ExprToExpression(ctx, aliasedExpr.Expr) if err != nil { return nil, err } @@ -1903,7 +1941,7 @@ func MustStringToColumnDefaultValue(ctx *sql.Context, exprStr string, outType sq return expr } -func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { +func ExprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error) { switch v := e.(type) { default: return nil, ErrUnsupportedSyntax.New(sqlparser.String(e)) @@ -1915,14 +1953,14 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error err error ) if v.Name != nil { - name, err = exprToExpression(ctx, v.Name) + name, err = ExprToExpression(ctx, v.Name) } else { - name, err = exprToExpression(ctx, v.StrVal) + name, err = ExprToExpression(ctx, v.StrVal) } if err != nil { return nil, err } - from, err := exprToExpression(ctx, v.From) + from, err := ExprToExpression(ctx, v.From) if err != nil { return nil, err } @@ -1930,7 +1968,7 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error if v.To == nil { return function.NewSubstring(name, from) } - to, err := exprToExpression(ctx, v.To) + to, err := ExprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -1940,7 +1978,7 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error case *sqlparser.IsExpr: return isExprToExpression(ctx, v) case *sqlparser.NotExpr: - c, err := exprToExpression(ctx, v.Expr) + c, err := ExprToExpression(ctx, v.Expr) if err != nil { return nil, err } @@ -1981,50 +2019,50 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error return expression.NewUnresolvedFunction(v.Name.Lowered(), isAggregateFunc(v), overToWindow(ctx, v.Over), exprs...), nil case *sqlparser.ParenExpr: - return exprToExpression(ctx, v.Expr) + return ExprToExpression(ctx, v.Expr) case *sqlparser.AndExpr: - lhs, err := exprToExpression(ctx, v.Left) + lhs, err := ExprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(ctx, v.Right) + rhs, err := ExprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewAnd(lhs, rhs), nil case *sqlparser.OrExpr: - lhs, err := exprToExpression(ctx, v.Left) + lhs, err := ExprToExpression(ctx, v.Left) if err != nil { return nil, err } - rhs, err := exprToExpression(ctx, v.Right) + rhs, err := ExprToExpression(ctx, v.Right) if err != nil { return nil, err } return expression.NewOr(lhs, rhs), nil case *sqlparser.ConvertExpr: - expr, err := exprToExpression(ctx, v.Expr) + expr, err := ExprToExpression(ctx, v.Expr) if err != nil { return nil, err } return expression.NewConvert(expr, v.Type.Type), nil case *sqlparser.RangeCond: - val, err := exprToExpression(ctx, v.Left) + val, err := ExprToExpression(ctx, v.Left) if err != nil { return nil, err } - lower, err := exprToExpression(ctx, v.From) + lower, err := ExprToExpression(ctx, v.From) if err != nil { return nil, err } - upper, err := exprToExpression(ctx, v.To) + upper, err := ExprToExpression(ctx, v.To) if err != nil { return nil, err } @@ -2040,7 +2078,7 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error case sqlparser.ValTuple: var exprs = make([]sql.Expression, len(v)) for i, e := range v { - expr, err := exprToExpression(ctx, e) + expr, err := ExprToExpression(ctx, e) if err != nil { return nil, err } @@ -2067,9 +2105,9 @@ func exprToExpression(ctx *sql.Context, e sqlparser.Expr) (sql.Expression, error return intervalExprToExpression(ctx, v) case *sqlparser.CollateExpr: // TODO: handle collation - return exprToExpression(ctx, v.Expr) + return ExprToExpression(ctx, v.Expr) case *sqlparser.ValuesFuncExpr: - col, err := exprToExpression(ctx, v.Name) + col, err := ExprToExpression(ctx, v.Name) if err != nil { return nil, err } @@ -2090,7 +2128,7 @@ func overToWindow(ctx *sql.Context, over *sqlparser.Over) *sql.Window { partitions := make([]sql.Expression, len(over.PartitionBy)) for i, expr := range over.PartitionBy { var err error - partitions[i], err = exprToExpression(ctx, expr) + partitions[i], err = ExprToExpression(ctx, expr) if err != nil { return nil } @@ -2179,7 +2217,7 @@ func convertVal(v *sqlparser.SQLVal) (sql.Expression, error) { } func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, error) { - e, err := exprToExpression(ctx, c.Expr) + e, err := ExprToExpression(ctx, c.Expr) if err != nil { return nil, err } @@ -2203,12 +2241,12 @@ func isExprToExpression(ctx *sql.Context, c *sqlparser.IsExpr) (sql.Expression, } func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) (sql.Expression, error) { - left, err := exprToExpression(ctx, c.Left) + left, err := ExprToExpression(ctx, c.Left) if err != nil { return nil, err } - right, err := exprToExpression(ctx, c.Right) + right, err := ExprToExpression(ctx, c.Right) if err != nil { return nil, err } @@ -2264,7 +2302,7 @@ func comparisonExprToExpression(ctx *sql.Context, c *sqlparser.ComparisonExpr) ( func groupByToExpressions(ctx *sql.Context, g sqlparser.GroupBy) ([]sql.Expression, error) { es := make([]sql.Expression, len(g)) for i, ve := range g { - e, err := exprToExpression(ctx, ve) + e, err := ExprToExpression(ctx, ve) if err != nil { return nil, err } @@ -2285,7 +2323,7 @@ func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expr } return expression.NewQualifiedStar(e.TableName.Name.String()), nil case *sqlparser.AliasedExpr: - expr, err := exprToExpression(ctx, e.Expr) + expr, err := ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2301,7 +2339,7 @@ func selectExprToExpression(ctx *sql.Context, se sqlparser.SelectExpr) (sql.Expr func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expression, error) { switch strings.ToLower(e.Operator) { case sqlparser.MinusStr: - expr, err := exprToExpression(ctx, e.Expr) + expr, err := ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2309,7 +2347,7 @@ func unaryExprToExpression(ctx *sql.Context, e *sqlparser.UnaryExpr) (sql.Expres return expression.NewUnaryMinus(expr), nil case sqlparser.PlusStr: // Unary plus expressions do nothing (do not turn the expression positive). Just return the underlying expression. - return exprToExpression(ctx, e.Expr) + return ExprToExpression(ctx, e.Expr) default: return nil, ErrUnsupportedFeature.New("unary operator: " + e.Operator) @@ -2331,12 +2369,12 @@ func binaryExprToExpression(ctx *sql.Context, be *sqlparser.BinaryExpr) (sql.Exp sqlparser.IntDivStr, sqlparser.ModStr: - l, err := exprToExpression(ctx, be.Left) + l, err := ExprToExpression(ctx, be.Left) if err != nil { return nil, err } - r, err := exprToExpression(ctx, be.Right) + r, err := ExprToExpression(ctx, be.Right) if err != nil { return nil, err } @@ -2367,7 +2405,7 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi var err error if e.Expr != nil { - expr, err = exprToExpression(ctx, e.Expr) + expr, err = ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2376,13 +2414,13 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi var branches []expression.CaseBranch for _, w := range e.Whens { var cond sql.Expression - cond, err = exprToExpression(ctx, w.Cond) + cond, err = ExprToExpression(ctx, w.Cond) if err != nil { return nil, err } var val sql.Expression - val, err = exprToExpression(ctx, w.Val) + val, err = ExprToExpression(ctx, w.Val) if err != nil { return nil, err } @@ -2395,7 +2433,7 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi var elseExpr sql.Expression if e.Else != nil { - elseExpr, err = exprToExpression(ctx, e.Else) + elseExpr, err = ExprToExpression(ctx, e.Else) if err != nil { return nil, err } @@ -2405,7 +2443,7 @@ func caseExprToExpression(ctx *sql.Context, e *sqlparser.CaseExpr) (sql.Expressi } func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.Expression, error) { - expr, err := exprToExpression(ctx, e.Expr) + expr, err := ExprToExpression(ctx, e.Expr) if err != nil { return nil, err } @@ -2416,11 +2454,11 @@ func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql. func setExprsToExpressions(ctx *sql.Context, e sqlparser.SetExprs) ([]sql.Expression, error) { res := make([]sql.Expression, len(e)) for i, updateExpr := range e { - colName, err := exprToExpression(ctx, updateExpr.Name) + colName, err := ExprToExpression(ctx, updateExpr.Name) if err != nil { return nil, err } - innerExpr, err := exprToExpression(ctx, updateExpr.Expr) + innerExpr, err := ExprToExpression(ctx, updateExpr.Expr) if err != nil { return nil, err } @@ -2434,7 +2472,7 @@ func convertShowTableStatus(ctx *sql.Context, s *sqlparser.Show) (sql.Node, erro if s.Filter != nil { if s.Filter.Filter != nil { var err error - filter, err = exprToExpression(ctx, s.Filter.Filter) + filter, err = ExprToExpression(ctx, s.Filter.Filter) if err != nil { return nil, err } diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index 883ac82c42..b3dd5db30d 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -48,493 +48,613 @@ var fixtures = map[string]sql.Node{ `CREATE TABLE t1(a INTEGER, b TEXT, c DATE, d TIMESTAMP, e VARCHAR(20), f BLOB NOT NULL, g DATETIME, h CHAR(40))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - }, { - Name: "c", - Type: sql.Date, - Nullable: true, - }, { - Name: "d", - Type: sql.Timestamp, - Nullable: true, - }, { - Name: "e", - Type: sql.MustCreateStringWithDefaults(sqltypes.VarChar, 20), - Nullable: true, - }, { - Name: "f", - Type: sql.Blob, - Nullable: false, - }, { - Name: "g", - Type: sql.Datetime, - Nullable: true, - }, { - Name: "h", - Type: sql.MustCreateStringWithDefaults(sqltypes.Char, 40), - Nullable: true, - }}, false, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + }, { + Name: "c", + Type: sql.Date, + Nullable: true, + }, { + Name: "d", + Type: sql.Timestamp, + Nullable: true, + }, { + Name: "e", + Type: sql.MustCreateStringWithDefaults(sqltypes.VarChar, 20), + Nullable: true, + }, { + Name: "f", + Type: sql.Blob, + Nullable: false, + }, { + Name: "g", + Type: sql.Datetime, + Nullable: true, + }, { + Name: "h", + Type: sql.MustCreateStringWithDefaults(sqltypes.Char, 40), + Nullable: true, + }}, + }, ), `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY, b TEXT)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + }, ), `CREATE TABLE t1(a INTEGER NOT NULL PRIMARY KEY COMMENT "hello", b TEXT COMMENT "goodbye")`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - Comment: "hello", - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - PrimaryKey: false, - Comment: "goodbye", - }}, false, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + Comment: "hello", + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + Comment: "goodbye", + }}, + }, ), `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: true, + PrimaryKey: false, + }}, + }, ), `CREATE TABLE t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: false, - PrimaryKey: true, - }}, false, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: false, + PrimaryKey: true, + }}, + }, ), `CREATE TABLE IF NOT EXISTS t1(a INTEGER, b TEXT, PRIMARY KEY (a, b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Text, - Nullable: false, - PrimaryKey: true, - }}, true, - nil, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Text, + Nullable: false, + PrimaryKey: true, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX idx_name (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "idx_name", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "idx_name", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX idx_name (b) COMMENT 'hi')`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "idx_name", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "hi", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "idx_name", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "hi", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, UNIQUE INDEX (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_Unique, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_Unique, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, UNIQUE (b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_Unique, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_Unique, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b, a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, - Comment: "", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b INTEGER, INDEX (b), INDEX (b, a))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - []*plan.IndexDefinition{{ - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}}, - Comment: "", - }, { - IndexName: "", - Using: sql.IndexUsing_Default, - Constraint: sql.IndexConstraint_None, - Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, - Comment: "", - }}, - nil, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + IdxDefs: []*plan.IndexDefinition{{ + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}}, + Comment: "", + }, { + IndexName: "", + Using: sql.IndexUsing_Default, + Constraint: sql.IndexConstraint_None, + Columns: []sql.IndexColumn{{"b", 0}, {"a", 0}}, + Comment: "", + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, CONSTRAINT fk_name FOREIGN KEY (b_id) REFERENCES t0(b))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "fk_name", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "fk_name", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON UPDATE CASCADE)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_Cascade, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_Cascade, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON DELETE RESTRICT)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_Restrict, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_Restrict, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, FOREIGN KEY (b_id) REFERENCES t0(b) ON UPDATE SET NULL ON DELETE NO ACTION)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b"}, - OnUpdate: sql.ForeignKeyReferenceOption_SetNull, - OnDelete: sql.ForeignKeyReferenceOption_NoAction, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }}, + + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b"}, + OnUpdate: sql.ForeignKeyReferenceOption_SetNull, + OnDelete: sql.ForeignKeyReferenceOption_NoAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, c_id BIGINT, FOREIGN KEY (b_id, c_id) REFERENCES t0(b, c))`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }, { - Name: "c_id", - Type: sql.Int64, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "", - Columns: []string{"b_id", "c_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b", "c"}, - OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, - OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }, { + Name: "c_id", + Type: sql.Int64, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "", + Columns: []string{"b_id", "c_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b", "c"}, + OnUpdate: sql.ForeignKeyReferenceOption_DefaultAction, + OnDelete: sql.ForeignKeyReferenceOption_DefaultAction, + }}, + }, ), `CREATE TABLE t1(a INTEGER PRIMARY KEY, b_id INTEGER, c_id BIGINT, CONSTRAINT fk_name FOREIGN KEY (b_id, c_id) REFERENCES t0(b, c) ON UPDATE RESTRICT ON DELETE CASCADE)`: plan.NewCreateTable( sql.UnresolvedDatabase(""), "t1", - sql.Schema{{ - Name: "a", - Type: sql.Int32, - Nullable: false, - PrimaryKey: true, - }, { - Name: "b_id", - Type: sql.Int32, - Nullable: true, - PrimaryKey: false, - }, { - Name: "c_id", - Type: sql.Int64, - Nullable: true, - PrimaryKey: false, - }}, false, - nil, - []*sql.ForeignKeyConstraint{{ - Name: "fk_name", - Columns: []string{"b_id", "c_id"}, - ReferencedTable: "t0", - ReferencedColumns: []string{"b", "c"}, - OnUpdate: sql.ForeignKeyReferenceOption_Restrict, - OnDelete: sql.ForeignKeyReferenceOption_Cascade, - }}, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }, { + Name: "b_id", + Type: sql.Int32, + Nullable: true, + PrimaryKey: false, + }, { + Name: "c_id", + Type: sql.Int64, + Nullable: true, + PrimaryKey: false, + }}, + FkDefs: []*sql.ForeignKeyConstraint{{ + Name: "fk_name", + Columns: []string{"b_id", "c_id"}, + ReferencedTable: "t0", + ReferencedColumns: []string{"b", "c"}, + OnUpdate: sql.ForeignKeyReferenceOption_Restrict, + OnDelete: sql.ForeignKeyReferenceOption_Cascade, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY, CHECK (a > 0))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY CHECK (a > 0))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY, CONSTRAINT ch1 CHECK (a > 0))`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "ch1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY CHECK (a > 0) ENFORCED)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }}, + }, + ), + `CREATE TABLE t1(a INTEGER PRIMARY KEY CHECK (a > 0) NOT ENFORCED)`: plan.NewCreateTable( + sql.UnresolvedDatabase(""), + "t1", + false, + &plan.TableSpec{ + Schema: sql.Schema{{ + Name: "a", + Type: sql.Int32, + Nullable: false, + PrimaryKey: true, + }}, + ChDefs: []*sql.CheckConstraint{{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: false, + }}, + }, ), `DROP TABLE foo;`: plan.NewDropTable( sql.UnresolvedDatabase(""), false, "foo", @@ -733,6 +853,39 @@ var fixtures = map[string]sql.Node{ OnDelete: sql.ForeignKeyReferenceOption_Cascade, }, ), + `ALTER TABLE t1 ADD CHECK (a > 0)`: plan.NewAlterAddCheck( + plan.NewUnresolvedTable("t1", ""), + &sql.CheckConstraint{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + ), + `ALTER TABLE t1 ADD CONSTRAINT ch1 CHECK (a > 0)`: plan.NewAlterAddCheck( + plan.NewUnresolvedTable("t1", ""), + &sql.CheckConstraint{ + Name: "ch1", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + ), + `ALTER TABLE t1 ADD CONSTRAINT CHECK (a > 0)`: plan.NewAlterAddCheck( + plan.NewUnresolvedTable("t1", ""), + &sql.CheckConstraint{ + Name: "", + Expr: expression.NewGreaterThan( + expression.NewUnresolvedColumn("a"), + expression.NewLiteral(int8(0), sql.Int8), + ), + Enforced: true, + }, + ), `ALTER TABLE t1 DROP FOREIGN KEY fk_name`: plan.NewAlterDropForeignKey( plan.NewUnresolvedTable("t1", ""), &sql.ForeignKeyConstraint{ @@ -1086,6 +1239,7 @@ var fixtures = map[string]sql.Node{ false, []string{"col1", "col2"}, []sql.Expression{}, + nil, ), `INSERT INTO t1 (col1, col2) VALUES (?, ?)`: plan.NewInsertInto( plan.NewUnresolvedTable("t1", ""), @@ -1096,6 +1250,7 @@ var fixtures = map[string]sql.Node{ false, []string{"col1", "col2"}, []sql.Expression{}, + nil, ), `UPDATE t1 SET col1 = ?, col2 = ? WHERE id = ?`: plan.NewUpdate( plan.NewFilter( @@ -1116,6 +1271,7 @@ var fixtures = map[string]sql.Node{ true, []string{"col1", "col2"}, []sql.Expression{}, + nil, ), `SHOW TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), false, nil), `SHOW FULL TABLES`: plan.NewShowTables(sql.UnresolvedDatabase(""), true, nil), @@ -2428,6 +2584,7 @@ var fixtures = map[string]sql.Node{ false, []string{"a", "b"}, []sql.Expression{}, + nil, ), }), ), @@ -2455,6 +2612,7 @@ var fixtures = map[string]sql.Node{ false, []string{"a", "b"}, []sql.Expression{}, + nil, ), `CREATE TRIGGER myTrigger BEFORE UPDATE ON foo FOR EACH ROW INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, `INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, @@ -2472,6 +2630,7 @@ var fixtures = map[string]sql.Node{ false, []string{"a", "b"}, []sql.Expression{}, + nil, ), `CREATE TRIGGER myTrigger BEFORE UPDATE ON foo FOR EACH ROW FOLLOWS yourTrigger INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, `INSERT INTO zzz (a,b) VALUES (old.a, old.b)`, diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go new file mode 100644 index 0000000000..272a6c4abe --- /dev/null +++ b/sql/plan/alter_check.go @@ -0,0 +1,224 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plan + +import ( + "fmt" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "gopkg.in/src-d/go-errors.v1" + "io" +) + +var ( + // ErrNoCheckConstraintSupport is returned when the table does not support CONSTRAINT CHECK operations. + ErrNoCheckConstraintSupport = errors.NewKind("the table does not support check constraint operations: %s") + + // ErrCheckFailed is returned when the check constraint evaluates to false + ErrCheckFailed = errors.NewKind("check constraint %s is violated.") +) + +type CreateCheck struct { + UnaryNode + Check *sql.CheckConstraint +} + +type DropCheck struct { + UnaryNode + Check *sql.CheckConstraint +} + +func NewAlterAddCheck(table sql.Node, check *sql.CheckConstraint) *CreateCheck { + return &CreateCheck{ + UnaryNode: UnaryNode{table}, + Check: check, + } +} + +func NewAlterDropCheck(table sql.Node, check *sql.CheckConstraint) *DropCheck { + return &DropCheck{ + UnaryNode: UnaryNode{Child: table}, + Check: check, + } +} + +func getCheckAlterable(node sql.Node) (sql.CheckAlterableTable, error) { + switch node := node.(type) { + case sql.CheckAlterableTable: + return node, nil + case *ResolvedTable: + return getCheckAlterableTable(node.Table) + default: + return nil, ErrNoCheckConstraintSupport.New(node.String()) + } +} + +func getCheckAlterableTable(t sql.Table) (sql.CheckAlterableTable, error) { + switch t := t.(type) { + case sql.CheckAlterableTable: + return t, nil + case sql.TableWrapper: + return getCheckAlterableTable(t.Underlying()) + case *ResolvedTable: + return getCheckAlterableTable(t.Table) + default: + return nil, ErrNoCheckConstraintSupport.New(t.Name()) + } +} + +// Expressions implements the sql.Expressioner interface. +func (c *CreateCheck) Expressions() []sql.Expression { + return []sql.Expression{c.Check.Expr} +} + +// Resolved implements the Resolvable interface. +func (c *CreateCheck) Resolved() bool { + ok := true + sql.Inspect(c.Check.Expr, func(expr sql.Expression) bool { + switch expr.(type) { + case *expression.UnresolvedColumn: + ok = false + return false + } + return true + }) + return ok +} + +// WithExpressions implements the sql.Expressioner interface. +func (c *CreateCheck) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { + if len(exprs) != 1 { + return nil, fmt.Errorf("expected one expression, got: %d", len(exprs)) + } + + nc := *c + nc.Check.Expr = exprs[0] + return &nc, nil +} + +// Execute inserts the rows in the database. +func (p *CreateCheck) Execute(ctx *sql.Context) error { + chAlterable, err := getCheckAlterable(p.UnaryNode.Child) + if err != nil { + return err + } + + // check existing rows in table + var res interface{} + rowIter, err := p.UnaryNode.Child.RowIter(ctx, nil) + if err != nil { + return err + } + + for { + row, err := rowIter.Next() + if row == nil || err != io.EOF { + break + } + res, err = p.Check.Expr.Eval(ctx, row) + if err != nil { + return err + } + if val, ok := res.(bool); !ok || !val { + return ErrCheckFailed.New(p.Check.Name) + } + } + + check, err := NewCheckDefinition(p.Check) + if err != nil { + return err + } + + return chAlterable.CreateCheck(ctx, check) +} + +// WithChildren implements the Node interface. +func (p *CreateCheck) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewAlterAddCheck(children[0], p.Check), nil +} + +func (p *CreateCheck) Schema() sql.Schema { return nil } + +func (p *CreateCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + err := p.Execute(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(), nil +} + +func (p CreateCheck) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("AddCheck(%s)", p.Check.Name) + _ = pr.WriteChildren( + fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String()), + fmt.Sprintf("Expr(%s)", p.Check.Expr.String()), + ) + return pr.String() +} + +// Execute inserts the rows in the database. +func (p *DropCheck) Execute(ctx *sql.Context) error { + chAlterable, err := getCheckAlterable(p.UnaryNode.Child) + if err != nil { + return err + } + return chAlterable.DropCheck(ctx, p.Check.Name) +} + +// RowIter implements the Node interface. +func (p *DropCheck) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + err := p.Execute(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(), nil +} + +// WithChildren implements the Node interface. +func (p *DropCheck) WithChildren(children ...sql.Node) (sql.Node, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + return NewAlterDropCheck(children[0], p.Check), nil +} +func (p *DropCheck) Schema() sql.Schema { return nil } + +func (p DropCheck) String() string { + pr := sql.NewTreePrinter() + _ = pr.WriteNode("DropCheck(%s)", p.Check.Name) + _ = pr.WriteChildren(fmt.Sprintf("Table(%s)", p.UnaryNode.Child.String())) + return pr.String() +} + +func NewCheckDefinition(check *sql.CheckConstraint) (*sql.CheckDefinition, error) { + enforced := "" + if !check.Enforced { + enforced = " NOT ENFORCED" + } + + return &sql.CheckDefinition{ + Name: check.Name, + AlterStatement: fmt.Sprintf( + "ALTER TABLE _ ADD CONSTRAINT %s CHECK (%s)%s", + check.Name, + check.Expr.String(), + enforced, + ), + }, nil +} diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index 4cae8be9df..c58b8f60b5 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -137,7 +137,7 @@ func (c *CreateIndex) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error } for _, e := range exprs { - if sql.IsBlob(e.Type()) || sql.IsJSON(e.Type()){ + if sql.IsBlob(e.Type()) || sql.IsJSON(e.Type()) { return nil, ErrExprTypeNotIndexable.New(e, e.Type()) } } diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index 1371257f9f..72f554bb4b 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -75,6 +75,38 @@ type IndexDefinition struct { Comment string } +// TableSpec is a node describing the schema of a table. +type TableSpec struct { + Schema sql.Schema + FkDefs []*sql.ForeignKeyConstraint + ChDefs []*sql.CheckConstraint + IdxDefs []*IndexDefinition +} + +func (c *TableSpec) WithSchema(schema sql.Schema) (*TableSpec, error) { + nc := *c + nc.Schema = schema + return &nc, nil +} + +func (c *TableSpec) WithForeignKeys(fkDefs []*sql.ForeignKeyConstraint) (*TableSpec, error) { + nc := *c + nc.FkDefs = fkDefs + return &nc, nil +} + +func (c *TableSpec) WithCheckConstraints(chDefs []*sql.CheckConstraint) (*TableSpec, error) { + nc := *c + nc.ChDefs = chDefs + return &nc, nil +} + +func (c *TableSpec) WithIndices(idxDefs []*IndexDefinition) (*TableSpec, error) { + nc := *c + nc.IdxDefs = idxDefs + return &nc, nil +} + // CreateTable is a node describing the creation of some table. type CreateTable struct { ddlNode @@ -82,6 +114,7 @@ type CreateTable struct { schema sql.Schema ifNotExists bool fkDefs []*sql.ForeignKeyConstraint + chDefs []*sql.CheckConstraint idxDefs []*IndexDefinition like sql.Node } @@ -91,18 +124,19 @@ var _ sql.Node = (*CreateTable)(nil) var _ sql.Expressioner = (*CreateTable)(nil) // NewCreateTable creates a new CreateTable node -func NewCreateTable(db sql.Database, name string, schema sql.Schema, ifNotExists bool, idxDefs []*IndexDefinition, fkDefs []*sql.ForeignKeyConstraint) *CreateTable { - for _, s := range schema { +func NewCreateTable(db sql.Database, name string, ifNotExists bool, tableSpec *TableSpec) *CreateTable { + for _, s := range tableSpec.Schema { s.Source = name } return &CreateTable{ ddlNode: ddlNode{db}, name: name, - schema: schema, + schema: tableSpec.Schema, + fkDefs: tableSpec.FkDefs, + chDefs: tableSpec.ChDefs, + idxDefs: tableSpec.IdxDefs, ifNotExists: ifNotExists, - idxDefs: idxDefs, - fkDefs: fkDefs, } } @@ -151,7 +185,7 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error } //TODO: in the event that foreign keys or indexes aren't supported, you'll be left with a created table and no foreign keys/indexes //this also means that if a foreign key or index fails, you'll only have what was declared up to the failure - if len(c.idxDefs) > 0 || len(c.fkDefs) > 0 { + if len(c.idxDefs) > 0 || len(c.fkDefs) > 0 || len(c.chDefs) > 0 { tableNode, ok, err := c.db.GetTableInsensitive(ctx, c.name) if err != nil { return sql.RowsToRowIter(), err @@ -183,6 +217,22 @@ func (c *CreateTable) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error } } } + if len(c.chDefs) > 0 { + chAlterable, ok := tableNode.(sql.CheckAlterableTable) + if !ok { + return sql.RowsToRowIter(), ErrNoCheckConstraintSupport.New(c.name) + } + for _, ch := range c.chDefs { + check, err := NewCheckDefinition(ch) + if err != nil { + return nil, err + } + err = chAlterable.CreateCheck(ctx, check) + if err != nil { + return sql.RowsToRowIter(), err + } + } + } } return sql.RowsToRowIter(), nil } diff --git a/sql/plan/ddl_test.go b/sql/plan/ddl_test.go index 1b6c54479c..d32fca75e4 100644 --- a/sql/plan/ddl_test.go +++ b/sql/plan/ddl_test.go @@ -100,7 +100,7 @@ func TestDropTable(t *testing.T) { } func createTable(t *testing.T, db sql.Database, name string, schema sql.Schema, ifNotExists bool) error { - c := NewCreateTable(db, name, schema, ifNotExists, nil, nil) + c := NewCreateTable(db, name, ifNotExists, &TableSpec{Schema: schema}) rows, err := c.RowIter(sql.NewEmptyContext(), nil) if err != nil { diff --git a/sql/plan/insert.go b/sql/plan/insert.go index a123a60453..87cada87cd 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -46,16 +46,18 @@ type InsertInto struct { ColumnNames []string IsReplace bool OnDupExprs []sql.Expression + Checks []sql.Expression } // NewInsertInto creates an InsertInto node. -func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string, onDupExprs []sql.Expression) *InsertInto { +func NewInsertInto(dst, src sql.Node, isReplace bool, cols []string, onDupExprs []sql.Expression, checks []sql.Expression) *InsertInto { return &InsertInto{ Destination: dst, Source: src, ColumnNames: cols, IsReplace: isReplace, OnDupExprs: onDupExprs, + Checks: checks, } } @@ -81,6 +83,7 @@ type insertIter struct { rowSource sql.RowIter ctx *sql.Context updateExprs []sql.Expression + checks []sql.Expression tableNode sql.Node closed bool } @@ -119,6 +122,7 @@ func newInsertIter( values sql.Node, isReplace bool, onDupUpdateExpr []sql.Expression, + checks []sql.Expression, row sql.Row, ) (*insertIter, error) { dstSchema := table.Schema() @@ -155,6 +159,7 @@ func newInsertIter( updater: updater, rowSource: rowIter, updateExprs: onDupUpdateExpr, + checks: checks, ctx: ctx, }, nil } @@ -182,6 +187,19 @@ func (i insertIter) Next() (returnRow sql.Row, returnErr error) { return nil, err } + // apply check constraints + var res interface{} + for _, check := range i.checks { + res, err = check.Eval(i.ctx, row) + if err != nil { + return nil, err + } + if val, ok := res.(bool); !ok || !val { + i.ctx.Warn(3819, "Check constraint '%s' is violated", check.String()) + return nil, nil + } + } + // Do any necessary type conversions to the target schema for i, col := range i.schema { if row[i] != nil { @@ -328,7 +346,7 @@ func (i insertIter) Close(ctx *sql.Context) error { // RowIter implements the Node interface. func (p *InsertInto) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return newInsertIter(ctx, p.Destination, p.Source, p.IsReplace, p.OnDupExprs, row) + return newInsertIter(ctx, p.Destination, p.Source, p.IsReplace, p.OnDupExprs, p.Checks, row) } // WithChildren implements the Node interface. @@ -382,15 +400,15 @@ func validateNullability(dstSchema sql.Schema, row sql.Row) error { } func (p *InsertInto) Expressions() []sql.Expression { - return p.OnDupExprs + return append(p.OnDupExprs, p.Checks...) } func (p *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { - if len(newExprs) != len(p.OnDupExprs) { - return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.OnDupExprs), 1) + if len(newExprs) != len(p.OnDupExprs)+len(p.Checks) { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(p.OnDupExprs)+len(p.Checks), 1) } - return NewInsertInto(p.Destination, p.Source, p.IsReplace, p.ColumnNames, newExprs), nil + return NewInsertInto(p.Destination, p.Source, p.IsReplace, p.ColumnNames, newExprs[:len(p.OnDupExprs)], newExprs[len(p.OnDupExprs):]), nil } // Resolved implements the Resolvable interface. @@ -403,5 +421,10 @@ func (p *InsertInto) Resolved() bool { return false } } + for _, checkExpr := range p.Checks { + if !checkExpr.Resolved() { + return false + } + } return true } diff --git a/sql/testutils.go b/sql/testutils.go index 41330776ce..7c30e60eac 100644 --- a/sql/testutils.go +++ b/sql/testutils.go @@ -30,4 +30,3 @@ func MustJSON(s string) JSONDocument { } return JSONDocument{Val: doc} } -