diff --git a/ddl/db_test.go b/ddl/db_test.go index 8e4b214e2ce75..49371d0e8d840 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -1832,6 +1832,59 @@ func (s *testDBSuite1) TestCreateTable(c *C) { c.Assert(err.Error(), Equals, "[types:1291]Column 'a' has duplicated value 'B' in ENUM") } +func (s *testDBSuite2) TestCreateTableWithSetCol(c *C) { + s.tk = testkit.NewTestKitWithInit(c, s.store) + s.tk.MustExec("create table t_set (a int, b set('e') default '');") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` set('e') DEFAULT ''\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('a', 'b', 'c', 'd') default 'a,C,c');") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('a','b','c','d') DEFAULT 'a,c'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // It's for failure cases. + // The type of default value is string. + s.tk.MustExec("drop table t_set") + failedSQL := "create table t_set (a set('1', '4', '10') default '3');" + s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default '1,4,11');" + s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default '1 ,4');" + s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault) + // The type of default value is int. + failedSQL = "create table t_set (a set('1', '4', '10') default 0);" + s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default 8);" + s.tk.MustGetErrCode(failedSQL, tmysql.ErrInvalidDefault) + + // The type of default value is int. + // It's for successful cases + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 2);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '4'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 3);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1,4'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 15);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1,4,10,21'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("insert into t_set value()") + s.tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21")) +} + func (s *testDBSuite2) TestTableForeignKey(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index d09ca0746b522..bdf76bc3d5efc 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -620,8 +620,8 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o return col, constraints, nil } -func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption, t *types.FieldType) (interface{}, error) { - tp, fsp := t.Tp, t.Decimal +func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) { + tp, fsp := col.FieldType.Tp, col.FieldType.Decimal if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { switch x := c.Expr.(type) { case *ast.FuncCallExpr: @@ -633,14 +633,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption } } if defaultFsp != fsp { - return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName) + return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) } } } vd, err := expression.GetTimeValue(ctx, c.Expr, tp, int8(fsp)) value := vd.GetValue() if err != nil { - return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName) + return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) } // Value is nil means `default null`. @@ -681,14 +681,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption return strconv.FormatUint(value, 10), nil } - if tp == mysql.TypeDuration { - var err error - if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, t); err != nil { + switch tp { + case mysql.TypeSet: + return setSetDefaultValue(v, col) + case mysql.TypeDuration: + if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil { return "", errors.Trace(err) } - } - - if tp == mysql.TypeBit { + case mysql.TypeBit: if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 { // For BIT fields, convert int into BinaryLiteral. return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil @@ -698,6 +698,58 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption return v.ToString() } +// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html. +func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) { + if v.Kind() == types.KindInt64 { + setCnt := len(col.Elems) + maxLimit := int64(1< maxLimit { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetValue(col.Elems, uint64(val)) + if err != nil { + return "", errors.Trace(err) + } + v.SetMysqlSet(setVal) + return v.ToString() + } + + str, err := v.ToString() + if err != nil { + return "", errors.Trace(err) + } + if str == "" { + return str, nil + } + + valMap := make(map[string]struct{}, len(col.Elems)) + dVals := strings.Split(strings.ToLower(str), ",") + for _, dv := range dVals { + valMap[dv] = struct{}{} + } + var existCnt int + for dv := range valMap { + for i := range col.Elems { + e := strings.ToLower(col.Elems[i]) + if e == dv { + existCnt++ + break + } + } + } + if existCnt != len(valMap) { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetName(col.Elems, str) + if err != nil { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + v.SetMysqlSet(setVal) + + return v.ToString() +} + func removeOnUpdateNowFlag(c *table.Column) { // For timestamp Col, if it is set null or default value, // OnUpdateNowFlag should be removed. @@ -2491,7 +2543,7 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error { func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) { hasDefaultValue := false - value, err := getDefaultValue(ctx, col.Name.L, option, &col.FieldType) + value, err := getDefaultValue(ctx, col, option) if err != nil { return hasDefaultValue, errors.Trace(err) } diff --git a/executor/seqtest/seq_executor_test.go b/executor/seqtest/seq_executor_test.go index c10ac5d426f38..391efeb5d43da 100644 --- a/executor/seqtest/seq_executor_test.go +++ b/executor/seqtest/seq_executor_test.go @@ -587,7 +587,7 @@ func (s *seqTestSuite) TestShow(c *C) { "c4|varchar(6)|YES||1|", "c5|varchar(6)|YES||'C6'|", "c6|enum('s','m','l','xl')|YES||xl|", - "c7|set('a','b','c','d')|YES||a,c,c|", + "c7|set('a','b','c','d')|YES||a,c|", "c8|datetime|YES||CURRENT_TIMESTAMP|DEFAULT_GENERATED on update CURRENT_TIMESTAMP", "c9|year(4)|YES||2014|", )) diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 533195ce83c2f..1336055068cd9 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -614,7 +614,7 @@ func checkColumn(colDef *ast.ColumnDef) error { if len(tp.Elems) > mysql.MaxTypeSetMembers { return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O) } - // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html . + // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html. for _, str := range colDef.Tp.Elems { if strings.Contains(str, ",") { return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str)