diff --git a/executor/analyze_test.go b/executor/analyze_test.go index 8f862265c7e22..b83c677237a91 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -258,7 +258,7 @@ func (s *testSuite1) TestFastAnalyze(c *C) { c.Assert(err, IsNil) tableInfo := table.Meta() tbl := dom.StatsHandle().GetTableStats(tableInfo) - c.Assert(tbl.String(), Equals, "Table:43 Count:20\n"+ + c.Assert(tbl.String(), Equals, "Table:45 Count:20\n"+ "column:1 ndv:20 totColSize:0\n"+ "num: 6 lower_bound: 3 upper_bound: 15 repeats: 1\n"+ "num: 7 lower_bound: 18 upper_bound: 33 repeats: 1\n"+ diff --git a/executor/builder.go b/executor/builder.go index 08ca8272d387f..e70a9f2eaf982 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -776,6 +776,7 @@ func (b *executorBuilder) buildGrant(grant *ast.GrantStmt) Executor { Level: grant.Level, Users: grant.Users, WithGrant: grant.WithGrant, + TLSOptions: grant.TLSOptions, is: b.is, } return e diff --git a/executor/grant.go b/executor/grant.go index d4a6fac507771..122cf4930399d 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -15,6 +15,7 @@ package executor import ( "context" + "encoding/json" "fmt" "strings" @@ -24,8 +25,10 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" ) @@ -47,6 +50,7 @@ type GrantExec struct { Level *ast.GrantLevel Users []*ast.UserSpec WithGrant bool + TLSOptions []*ast.TLSOption is infoschema.InfoSchema done bool @@ -86,9 +90,16 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } // If there is no privilege entry in corresponding table, insert a new one. - // DB scope: mysql.DB - // Table scope: mysql.Tables_priv - // Column scope: mysql.Columns_priv + // Global scope: mysql.global_priv + // DB scope: mysql.DB + // Table scope: mysql.Tables_priv + // Column scope: mysql.Columns_priv + if e.TLSOptions != nil { + err = checkAndInitGlobalPriv(e.ctx, user.User.Username, user.User.Hostname) + if err != nil { + return err + } + } switch e.Level.Level { case ast.GrantLevelDB: err := checkAndInitDBPriv(e.ctx, dbName, e.is, user.User.Username, user.User.Hostname) @@ -113,7 +124,11 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() } - + // Grant global priv to user. + err = e.grantGlobalPriv(user) + if err != nil { + return err + } // Grant each priv to the user. for _, priv := range privs { if len(priv.Cols) > 0 { @@ -124,7 +139,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { return err } } - err := e.grantPriv(priv, user) + err := e.grantLevelPriv(priv, user) if err != nil { return err } @@ -134,6 +149,20 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { return nil } +// checkAndInitGlobalPriv checks if global scope privilege entry exists in mysql.global_priv. +// If not exists, insert a new one. +func checkAndInitGlobalPriv(ctx sessionctx.Context, user string, host string) error { + ok, err := globalPrivEntryExists(ctx, user, host) + if err != nil { + return err + } + if ok { + return nil + } + // Entry does not exist for user-host-db. Insert a new entry. + return initGlobalPrivEntry(ctx, user, host) +} + // checkAndInitDBPriv checks if DB scope privilege entry exists in mysql.DB. // If unexists, insert a new one. func checkAndInitDBPriv(ctx sessionctx.Context, dbName string, is infoschema.InfoSchema, user string, host string) error { @@ -190,6 +219,13 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast return nil } +// initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege. +func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error { + sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") + _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) + return err +} + // initDBPrivEntry inserts a new row into mysql.DB with empty privilege. func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error { sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db) @@ -211,25 +247,110 @@ func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db st return err } -// grantPriv grants priv to user in s.Level scope. -func (e *GrantExec) grantPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantGlobalPriv grants priv to user in global scope. +func (e *GrantExec) grantGlobalPriv(user *ast.UserSpec) error { + if len(e.TLSOptions) == 0 { + return nil + } + priv, err := tlsOption2GlobalPriv(e.TLSOptions) + if err != nil { + return errors.Trace(err) + } + sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) + _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) + return err +} + +var emptyGP = privileges.GlobalPrivValue{SSLType: privileges.SslTypeNotSpecified} + +func tlsOption2GlobalPriv(tlsOptions []*ast.TLSOption) (priv []byte, err error) { + if len(tlsOptions) == 0 { + priv = []byte("{}") + return + } + dupSet := make(map[int]struct{}) + for _, opt := range tlsOptions { + if _, dup := dupSet[opt.Type]; dup { + var typeName string + switch opt.Type { + case ast.Cipher: + typeName = "CIPHER" + case ast.Issuer: + typeName = "ISSUER" + case ast.Subject: + typeName = "SUBJECT" + } + err = errors.Errorf("Duplicate require %s clause", typeName) + return + } + dupSet[opt.Type] = struct{}{} + } + gp := privileges.GlobalPrivValue{SSLType: privileges.SslTypeNotSpecified} + for _, tlsOpt := range tlsOptions { + switch tlsOpt.Type { + case ast.TslNone: + gp.SSLType = privileges.SslTypeNone + case ast.Ssl: + gp.SSLType = privileges.SslTypeAny + case ast.X509: + gp.SSLType = privileges.SslTypeX509 + case ast.Cipher: + gp.SSLType = privileges.SslTypeSpecified + if len(tlsOpt.Value) > 0 { + if _, ok := util.SupportCipher[tlsOpt.Value]; !ok { + err = errors.Errorf("Unsupported cipher suit: %s", tlsOpt.Value) + return + } + gp.SSLCipher = tlsOpt.Value + } + case ast.Issuer: + err = util.CheckSupportX509NameOneline(tlsOpt.Value) + if err != nil { + return + } + gp.SSLType = privileges.SslTypeSpecified + gp.X509Issuer = tlsOpt.Value + case ast.Subject: + err = util.CheckSupportX509NameOneline(tlsOpt.Value) + if err != nil { + return + } + gp.SSLType = privileges.SslTypeSpecified + gp.X509Subject = tlsOpt.Value + default: + err = errors.Errorf("Unknown ssl type: %#v", tlsOpt.Type) + return + } + } + if gp == emptyGP { + return + } + priv, err = json.Marshal(&gp) + if err != nil { + return + } + return +} + +// grantLevelPriv grants priv to user in s.Level scope. +func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec) error { switch e.Level.Level { case ast.GrantLevelGlobal: - return e.grantGlobalPriv(priv, user) + return e.grantGlobalLevel(priv, user) case ast.GrantLevelDB: - return e.grantDBPriv(priv, user) + return e.grantDBLevel(priv, user) case ast.GrantLevelTable: if len(priv.Cols) == 0 { - return e.grantTablePriv(priv, user) + return e.grantTableLevel(priv, user) } - return e.grantColumnPriv(priv, user) + return e.grantColumnLevel(priv, user) default: return errors.Errorf("Unknown grant level: %#v", e.Level) } } -// grantGlobalPriv manipulates mysql.user table. -func (e *GrantExec) grantGlobalPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantGlobalLevel manipulates mysql.user table. +func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec) error { if priv.Priv == 0 { return nil } @@ -242,8 +363,8 @@ func (e *GrantExec) grantGlobalPriv(priv *ast.PrivElem, user *ast.UserSpec) erro return err } -// grantDBPriv manipulates mysql.db table. -func (e *GrantExec) grantDBPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantDBLevel manipulates mysql.db table. +func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec) error { dbName := e.Level.DBName if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB @@ -257,8 +378,8 @@ func (e *GrantExec) grantDBPriv(priv *ast.PrivElem, user *ast.UserSpec) error { return err } -// grantTablePriv manipulates mysql.tables_priv table. -func (e *GrantExec) grantTablePriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantTableLevel manipulates mysql.tables_priv table. +func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec) error { dbName := e.Level.DBName if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB @@ -273,8 +394,8 @@ func (e *GrantExec) grantTablePriv(priv *ast.PrivElem, user *ast.UserSpec) error return err } -// grantColumnPriv manipulates mysql.tables_priv table. -func (e *GrantExec) grantColumnPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantColumnLevel manipulates mysql.tables_priv table. +func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec) error { dbName, tbl, err := getTargetSchemaAndTable(e.ctx, e.Level.DBName, e.Level.TableName, e.is) if err != nil { return err @@ -473,6 +594,12 @@ func recordExists(ctx sessionctx.Context, sql string) (bool, error) { return len(rows) > 0, nil } +// globalPrivEntryExists checks if there is an entry with key user-host in mysql.global_priv. +func globalPrivEntryExists(ctx sessionctx.Context, name string, host string) (bool, error) { + sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) + return recordExists(ctx, sql) +} + // dbUserExists checks if there is an entry with key user-host-db in mysql.DB. func dbUserExists(ctx sessionctx.Context, name string, host string, db string) (bool, error) { sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, name, host, db) diff --git a/executor/grant_test.go b/executor/grant_test.go index f652ec37a1448..d5724a4cda089 100644 --- a/executor/grant_test.go +++ b/executor/grant_test.go @@ -237,3 +237,73 @@ func (s *testSuite3) TestGrantUnderANSIQuotes(c *C) { tk.MustExec(`REVOKE ALL PRIVILEGES ON video_ulimit.* FROM web@'%';`) tk.MustExec(`DROP USER IF EXISTS 'web'@'%'`) } + +func (s *testSuite3) TestMaintainRequire(c *C) { + tk := testkit.NewTestKit(c, s.store) + + // test create with require + tk.MustExec(`CREATE USER 'ssl_auser'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'`) + tk.MustExec(`CREATE USER 'ssl_buser'@'%' require subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'`) + tk.MustExec(`CREATE USER 'ssl_cuser'@'%' require cipher 'AES128-GCM-SHA256'`) + tk.MustExec(`CREATE USER 'ssl_duser'@'%'`) + tk.MustExec(`CREATE USER 'ssl_euser'@'%' require none`) + tk.MustExec(`CREATE USER 'ssl_fuser'@'%' require ssl`) + tk.MustExec(`CREATE USER 'ssl_guser'@'%' require x509`) + tk.MustQuery("select * from mysql.global_priv where `user` like 'ssl_%'").Check(testkit.Rows( + "% ssl_auser {\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\",\"x509_issuer\":\"/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}", + "% ssl_buser {\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}", + "% ssl_cuser {\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\"}", + "% ssl_duser {}", + "% ssl_euser {}", + "% ssl_fuser {\"ssl_type\":1}", + "% ssl_guser {\"ssl_type\":2}", + )) + + // test grant with require + tk.MustExec("CREATE USER 'u1'@'%'") + tk.MustExec("GRANT ALL ON *.* TO 'u1'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' and subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH'") // add new require. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{\"ssl_type\":3,\"x509_issuer\":\"/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}")) + tk.MustExec("GRANT ALL ON *.* TO 'u1'@'%' require cipher 'AES128-GCM-SHA256'") // modify always overwrite. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\"}")) + tk.MustExec("GRANT select ON *.* TO 'u1'@'%'") // modify without require should not modify old require. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\"}")) + tk.MustExec("GRANT ALL ON *.* TO 'u1'@'%' require none") // use require none to clean up require. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{}")) + + // test alter with require + tk.MustExec("CREATE USER 'u2'@'%'") + tk.MustExec("alter user 'u2'@'%' require ssl") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{\"ssl_type\":1}")) + tk.MustExec("alter user 'u2'@'%' require x509") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{\"ssl_type\":2}")) + tk.MustExec("alter user 'u2'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\",\"x509_issuer\":\"/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}")) + tk.MustExec("alter user 'u2'@'%' require none") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{}")) + + // test show create user + tk.MustExec(`CREATE USER 'u3'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'`) + tk.MustQuery("show create user 'u3'").Check(testkit.Rows("CREATE USER 'u3'@'%' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE CIPHER 'AES128-GCM-SHA256' ISSUER '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' SUBJECT '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) + + // check issuer/subject/cipher value + _, err := tk.Exec(`CREATE USER 'u4'@'%' require issuer 'CN=TiDB,OU=PingCAP'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u5'@'%' require subject '/CN=TiDB\OU=PingCAP'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u6'@'%' require subject '/CN=TiDB\NC=PingCAP'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u7'@'%' require cipher 'AES128-GCM-SHA1'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u8'@'%' require subject '/CN'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u9'@'%' require cipher 'TLS_AES_256_GCM_SHA384' cipher 'RC4-SHA'`) + c.Assert(err.Error(), Equals, "Duplicate require CIPHER clause") + _, err = tk.Exec(`CREATE USER 'u9'@'%' require issuer 'CN=TiDB,OU=PingCAP' issuer 'CN=TiDB,OU=PingCAP2'`) + c.Assert(err.Error(), Equals, "Duplicate require ISSUER clause") + _, err = tk.Exec(`CREATE USER 'u9'@'%' require subject '/CN=TiDB\OU=PingCAP' subject '/CN=TiDB\OU=PingCAP2'`) + c.Assert(err.Error(), Equals, "Duplicate require SUBJECT clause") + _, err = tk.Exec(`CREATE USER 'u9'@'%' require ssl ssl`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u9'@'%' require x509 x509`) + c.Assert(err, NotNil) +} diff --git a/executor/show.go b/executor/show.go index b96864c9811eb..98d5b5d431df8 100644 --- a/executor/show.go +++ b/executor/show.go @@ -16,6 +16,7 @@ package executor import ( "bytes" "context" + gjson "encoding/json" "fmt" "sort" "strconv" @@ -41,6 +42,7 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/tikv" @@ -49,6 +51,7 @@ import ( "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/format" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/sqlexec" ) @@ -984,8 +987,24 @@ func (e *ShowExec) fetchShowCreateUser() error { return ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", fmt.Sprintf("'%s'@'%s'", e.User.Username, e.User.Hostname)) } - showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", - e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname)) + sql = fmt.Sprintf(`SELECT PRIV FROM %s.%s WHERE User='%s' AND Host='%s'`, + mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) + rows, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) + if err != nil { + return errors.Trace(err) + } + require := "NONE" + if len(rows) == 1 { + privData := rows[0].GetString(0) + var privValue privileges.GlobalPrivValue + err = gjson.Unmarshal(hack.Slice(privData), &privValue) + if err != nil { + return errors.Trace(err) + } + require = privValue.RequireStr() + } + showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE %s PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", + e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname), require) e.appendRow([]interface{}{showStr}) return nil } diff --git a/executor/simple.go b/executor/simple.go index 2242cc1916e7b..2ad86d9bce0a9 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" @@ -659,7 +660,13 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm } } + privData, err := tlsOption2GlobalPriv(s.TLSOptions) + if err != nil { + return err + } + users := make([]string, 0, len(s.Specs)) + privs := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) if err1 != nil { @@ -683,6 +690,11 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm user = fmt.Sprintf(`('%s', '%s', '%s', 'Y')`, spec.User.Hostname, spec.User.Username, pwd) } users = append(users, user) + + if len(privData) != 0 { + priv := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, hack.String(privData)) + privs = append(privs, priv) + } } if len(users) == 0 { return nil @@ -692,10 +704,37 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if s.IsCreateRole { sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) } - _, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) + + restrictedCtx, err := e.getSysSession() if err != nil { return err } + defer e.releaseSysSession(restrictedCtx) + sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) + + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return errors.Trace(err) + } + _, err = sqlExecutor.Execute(context.Background(), sql) + if err != nil { + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + if len(privs) != 0 { + sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (Host, User, Priv) VALUES %s", mysql.SystemDB, mysql.GlobalPrivTable, strings.Join(privs, ", ")) + _, err = sqlExecutor.Execute(context.Background(), sql) + if err != nil { + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + } + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return errors.Trace(err) + } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return err } @@ -713,6 +752,11 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { s.Specs = []*ast.UserSpec{spec} } + privData, err := tlsOption2GlobalPriv(s.TLSOptions) + if err != nil { + return err + } + failedUsers := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err := userExists(e.ctx, spec.User.Username, spec.User.Hostname) @@ -740,6 +784,15 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { if err != nil { failedUsers = append(failedUsers, spec.User.String()) } + + if len(privData) > 0 { + sql = fmt.Sprintf("INSERT INTO %s.%s (Host, User, Priv) VALUES ('%s','%s','%s') ON DUPLICATE KEY UPDATE Priv = values(Priv)", + mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, hack.String(privData)) + _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) + if err != nil { + failedUsers = append(failedUsers, spec.User.String()) + } + } } if len(failedUsers) > 0 { // Commit the transaction even if we returns error @@ -873,6 +926,16 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { break } + // delete privileges from mysql.global_priv + sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + failedUsers = append(failedUsers, user.String()) + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + return err + } + continue + } + // delete privileges from mysql.db sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { diff --git a/go.mod b/go.mod index 145cbc9930c51..fc51faf5b87c9 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd - github.com/pingcap/parser v0.0.0-20191219054832-7df8c2c0e634 + github.com/pingcap/parser v0.0.0-20191224043251-93f4d5ec2623 github.com/pingcap/pd v1.1.0-beta.0.20190912093418-dc03c839debd github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible github.com/pingcap/tipb v0.0.0-20191120045257-1b9900292ab6 diff --git a/go.sum b/go.sum index e7eed589b6bc3..e1a3f3b79c772 100644 --- a/go.sum +++ b/go.sum @@ -153,8 +153,8 @@ github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d h1:zTHgLr8+0LTEJmj github.com/pingcap/kvproto v0.0.0-20191106014506-c5d88d699a8d/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg= github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= -github.com/pingcap/parser v0.0.0-20191219054832-7df8c2c0e634 h1:2gAX2u/8VTLox9qF138DwqkZTJD2oGOwbGhTcZHR8lQ= -github.com/pingcap/parser v0.0.0-20191219054832-7df8c2c0e634/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v0.0.0-20191224043251-93f4d5ec2623 h1:/BJjVyJlNKWMMrgPsbzk5Y9VPJWwHKYttj3oWxnFQ9U= +github.com/pingcap/parser v0.0.0-20191224043251-93f4d5ec2623/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd v1.1.0-beta.0.20190912093418-dc03c839debd h1:bKj6hodu/ro78B0oN2yicdGn0t4yd9XjnyoW95qmWic= github.com/pingcap/pd v1.1.0-beta.0.20190912093418-dc03c839debd/go.mod h1:I7TEby5BHTYIxgHszfsOJSBsk8b2Qt8QrSIgdv5n5QQ= github.com/pingcap/tidb-tools v3.0.6-0.20191119150227-ff0a3c6e5763+incompatible h1:I8HirWsu1MZp6t9G/g8yKCEjJJxtHooKakEgccvdJ4M= @@ -181,7 +181,6 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFd github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 h1:FUL3b97ZY2EPqg2NbXKuMHs5pXJB9hjj1fDHnF2vl28= github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc= github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= diff --git a/privilege/privilege.go b/privilege/privilege.go index 9362239618441..7e1655e151a23 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -14,6 +14,7 @@ package privilege import ( + "crypto/tls" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -45,7 +46,7 @@ type Manager interface { RequestVerificationWithUser(db, table, column string, priv mysql.PrivilegeType, user *auth.UserIdentity) bool // ConnectionVerification verifies user privilege for connection. - ConnectionVerification(user, host string, auth, salt []byte) (string, string, bool) + ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index f82f50e7dbbed..ec17da7e89ca0 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -15,7 +15,9 @@ package privileges import ( "context" + "encoding/json" "fmt" + "go.uber.org/zap" "sort" "strings" "sync/atomic" @@ -29,6 +31,8 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/stringutil" log "github.com/sirupsen/logrus" @@ -62,6 +66,71 @@ type UserRecord struct { patTypes []byte } +type globalPrivRecord struct { + Host string + User string + Priv GlobalPrivValue + Broken bool + + // patChars is compiled from Host, cached for pattern match performance. + patChars []byte + patTypes []byte +} + +// SSLType is enum value for GlobalPrivValue.SSLType. +// the value is compatible with MySQL storage json value. +type SSLType int + +const ( + // SslTypeNotSpecified indicates . + SslTypeNotSpecified SSLType = iota - 1 + // SslTypeNone indicates not require use ssl. + SslTypeNone + // SslTypeAny indicates require use ssl but not validate cert. + SslTypeAny + // SslTypeX509 indicates require use ssl and validate cert. + SslTypeX509 + // SslTypeSpecified indicates require use ssl and validate cert's subject or issuer. + SslTypeSpecified +) + +// GlobalPrivValue is store json format for priv column in mysql.global_priv. +type GlobalPrivValue struct { + SSLType SSLType `json:"ssl_type,omitempty"` + SSLCipher string `json:"ssl_cipher,omitempty"` + X509Issuer string `json:"x509_issuer,omitempty"` + X509Subject string `json:"x509_subject,omitempty"` +} + +// RequireStr returns describe string after `REQUIRE` clause. +func (g *GlobalPrivValue) RequireStr() string { + require := "NONE" + switch g.SSLType { + case SslTypeAny: + require = "SSL" + case SslTypeX509: + require = "X509" + case SslTypeSpecified: + var s []string + if len(g.SSLCipher) > 0 { + s = append(s, "CIPHER") + s = append(s, "'"+g.SSLCipher+"'") + } + if len(g.X509Issuer) > 0 { + s = append(s, "ISSUER") + s = append(s, "'"+g.X509Issuer+"'") + } + if len(g.X509Subject) > 0 { + s = append(s, "SUBJECT") + s = append(s, "'"+g.X509Subject+"'") + } + if len(s) > 0 { + require = strings.Join(s, " ") + } + } + return require +} + type dbRecord struct { Host string DB string @@ -138,6 +207,7 @@ func (g roleGraphEdgesTable) Find(user, host string) bool { // MySQLPrivilege is the in-memory cache of mysql privilege tables. type MySQLPrivilege struct { User []UserRecord + Global map[string][]globalPrivRecord DB []dbRecord TablesPriv []tablesPrivRecord ColumnsPriv []columnsPrivRecord @@ -190,6 +260,11 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error { return errors.Trace(err) } + err = p.LoadGlobalPrivTable(ctx) + if err != nil { + return errors.Trace(err) + } + err = p.LoadDBTable(ctx) if err != nil { if !noSuchTable(err) { @@ -354,6 +429,11 @@ func (p MySQLPrivilege) SortUserTable() { sort.Sort(sortedUserRecord(p.User)) } +// LoadGlobalPrivTable loads the mysql.global_priv table from database. +func (p *MySQLPrivilege) LoadGlobalPrivTable(ctx sessionctx.Context) error { + return p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Priv from mysql.global_priv", p.decodeGlobalPrivTableRow) +} + // LoadDBTable loads the mysql.db table from database. func (p *MySQLPrivilege) LoadDBTable(ctx sessionctx.Context) error { return p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv from mysql.db order by host, db, user;", p.decodeDBTableRow) @@ -438,6 +518,40 @@ func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField return nil } +func (p *MySQLPrivilege) decodeGlobalPrivTableRow(row chunk.Row, fs []*ast.ResultField) error { + var value globalPrivRecord + for i, f := range fs { + switch { + case f.ColumnAsName.L == "host": + value.Host = row.GetString(i) + value.patChars, value.patTypes = stringutil.CompilePattern(value.Host, '\\') + case f.ColumnAsName.L == "user": + value.User = row.GetString(i) + case f.ColumnAsName.L == "priv": + privData := row.GetString(i) + if len(privData) > 0 { + var privValue GlobalPrivValue + err := json.Unmarshal(hack.Slice(privData), &privValue) + if err != nil { + logutil.Logger(context.Background()).Error("one user global priv data is broken, forbidden login until data be fixed", + zap.String("user", value.User), zap.String("host", value.Host)) + value.Broken = true + } else { + value.Priv.SSLType = privValue.SSLType + value.Priv.SSLCipher = privValue.SSLCipher + value.Priv.X509Issuer = privValue.X509Issuer + value.Priv.X509Subject = privValue.X509Subject + } + } + } + } + if p.Global == nil { + p.Global = make(map[string][]globalPrivRecord) + } + p.Global[value.User] = append(p.Global[value.User], value) + return nil +} + func (p *MySQLPrivilege) decodeDBTableRow(row chunk.Row, fs []*ast.ResultField) error { var value dbRecord for i, f := range fs { @@ -577,6 +691,10 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { return ret } +func (record *globalPrivRecord) match(user, host string) bool { + return record.User == user && patternMatch(host, record.patChars, record.patTypes) +} + func (record *UserRecord) match(user, host string) bool { return record.User == user && patternMatch(host, record.patChars, record.patTypes) } @@ -620,6 +738,20 @@ func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { return nil } +func (p *MySQLPrivilege) matchGlobalPriv(user, host string) *globalPrivRecord { + uGlobal, exists := p.Global[user] + if !exists { + return nil + } + for i := 0; i < len(uGlobal); i++ { + record := &uGlobal[i] + if record.match(user, host) { + return record + } + } + return nil +} + func (p *MySQLPrivilege) matchUser(user, host string) *UserRecord { for i := 0; i < len(p.User); i++ { record := &p.User[i] diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index 3681c63f6f168..5f6ead6090891 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -76,6 +76,25 @@ func (s *testCacheSuite) TestLoadUserTable(c *C) { c.Assert(user[3].Privileges, Equals, mysql.CreateUserPriv|mysql.IndexPriv|mysql.ExecutePriv|mysql.CreateViewPriv|mysql.ShowViewPriv|mysql.ShowDBPriv|mysql.SuperPriv|mysql.TriggerPriv) } +func (s *testCacheSuite) TestLoadGlobalPrivTable(c *C) { + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + defer se.Close() + mustExec(c, se, "use mysql;") + mustExec(c, se, "truncate table global_priv") + + mustExec(c, se, `INSERT INTO mysql.global_priv VALUES ("%", "tu", "{\"access\":0,\"plugin\":\"mysql_native_password\",\"ssl_type\":3, \"ssl_cipher\":\"cipher\",\"x509_subject\":\"\C=ZH1\", \"x509_issuer\":\"\C=ZH2\", \"password_last_changed\":1}")`) + + var p privileges.MySQLPrivilege + err = p.LoadGlobalPrivTable(se) + c.Assert(err, IsNil) + c.Assert(p.Global["tu"][0].Host, Equals, `%`) + c.Assert(p.Global["tu"][0].User, Equals, `tu`) + c.Assert(p.Global["tu"][0].Priv.SSLType, Equals, privileges.SslTypeSpecified) + c.Assert(p.Global["tu"][0].Priv.X509Issuer, Equals, "C=ZH2") + c.Assert(p.Global["tu"][0].Priv.X509Subject, Equals, "C=ZH1") +} + func (s *testCacheSuite) TestLoadDBTable(c *C) { se, err := session.CreateSession4Test(s.store) c.Assert(err, IsNil) @@ -397,6 +416,25 @@ func (s *testCacheSuite) TestSortUserTable(c *C) { checkUserRecord(p.User, result, c) } +func (s *testCacheSuite) TestGlobalPrivValueRequireStr(c *C) { + var ( + none = privileges.GlobalPrivValue{SSLType: privileges.SslTypeNone} + tls = privileges.GlobalPrivValue{SSLType: privileges.SslTypeAny} + x509 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeX509} + spec = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified, SSLCipher: "c1", X509Subject: "s1", X509Issuer: "i1"} + spec2 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified, X509Subject: "s1", X509Issuer: "i1"} + spec3 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified, X509Issuer: "i1"} + spec4 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified} + ) + c.Assert(none.RequireStr(), Equals, "NONE") + c.Assert(tls.RequireStr(), Equals, "SSL") + c.Assert(x509.RequireStr(), Equals, "X509") + c.Assert(spec.RequireStr(), Equals, "CIPHER 'c1' ISSUER 'i1' SUBJECT 's1'") + c.Assert(spec2.RequireStr(), Equals, "ISSUER 'i1' SUBJECT 's1'") + c.Assert(spec3.RequireStr(), Equals, "ISSUER 'i1'") + c.Assert(spec4.RequireStr(), Equals, "NONE") +} + func checkUserRecord(x, y []privileges.UserRecord, c *C) { c.Assert(len(x), Equals, len(y)) for i := 0; i < len(x); i++ { diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index f6477fc001cf6..0aba853727e6f 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -15,6 +15,8 @@ package privileges import ( "context" + "crypto/tls" + "fmt" "strings" "github.com/pingcap/parser/auth" @@ -22,6 +24,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -102,7 +105,7 @@ func (p *UserPrivileges) GetEncodedPassword(user, host string) string { } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte) (u string, h string, success bool) { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { if SkipWithGrant { p.user = user p.host = host @@ -121,6 +124,16 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio u = record.User h = record.Host + globalPriv := mysqlPriv.matchGlobalPriv(user, host) + if globalPriv != nil { + if !p.checkSSL(globalPriv, tlsState) { + logutil.Logger(context.Background()).Error("global priv check ssl fail", + zap.String("user", user), zap.String("host", host)) + success = false + return + } + } + // Login a locked account is not allowed. locked := record.AccountLocked if locked { @@ -164,6 +177,102 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } +type checkResult int + +const ( + notCheck checkResult = iota + pass + fail +) + +func (p *UserPrivileges) checkSSL(priv *globalPrivRecord, tlsState *tls.ConnectionState) bool { + if priv.Broken { + logutil.Logger(context.Background()).Debug("ssl check failure, due to broken global_priv record", + zap.String("user", priv.User), zap.String("host", priv.Host)) + return false + } + switch priv.Priv.SSLType { + case SslTypeNotSpecified, SslTypeNone: + return true + case SslTypeAny: + r := tlsState != nil + if !r { + logutil.Logger(context.Background()).Debug("ssl check failure, require ssl but not use ssl", + zap.String("user", priv.User), zap.String("host", priv.Host)) + } + return r + case SslTypeX509: + if tlsState == nil { + logutil.Logger(context.Background()).Debug("ssl check failure, require x509 but not use ssl", + zap.String("user", priv.User), zap.String("host", priv.Host)) + return false + } + hasCert := false + for _, chain := range tlsState.VerifiedChains { + if len(chain) > 0 { + hasCert = true + break + } + } + if !hasCert { + logutil.Logger(context.Background()).Debug("ssl check failure, require x509 but no verified cert", + zap.String("user", priv.User), zap.String("host", priv.Host)) + } + return hasCert + case SslTypeSpecified: + if tlsState == nil { + logutil.Logger(context.Background()).Debug("ssl check failure, require subject/issuer/cipher but not use ssl", + zap.String("user", priv.User), zap.String("host", priv.Host)) + return false + } + if len(priv.Priv.SSLCipher) > 0 && priv.Priv.SSLCipher != util.TLSCipher2String(tlsState.CipherSuite) { + logutil.Logger(context.Background()).Debug("ssl check failure for cipher", zap.String("user", priv.User), zap.String("host", priv.Host), + zap.String("require", priv.Priv.SSLCipher), zap.String("given", util.TLSCipher2String(tlsState.CipherSuite))) + return false + } + var ( + hasCert = false + matchIssuer checkResult + matchSubject checkResult + ) + for _, chain := range tlsState.VerifiedChains { + if len(chain) == 0 { + continue + } + cert := chain[0] + if len(priv.Priv.X509Issuer) > 0 { + given := util.X509NameOnline(cert.Issuer) + if priv.Priv.X509Issuer == given { + matchIssuer = pass + } else if matchIssuer == notCheck { + matchIssuer = fail + logutil.Logger(context.Background()).Debug("ssl check failure for issuer", zap.String("user", priv.User), zap.String("host", priv.Host), + zap.String("require", priv.Priv.X509Issuer), zap.String("given", given)) + } + } + if len(priv.Priv.X509Subject) > 0 { + given := util.X509NameOnline(cert.Subject) + if priv.Priv.X509Subject == given { + matchSubject = pass + } else if matchSubject == notCheck { + matchSubject = fail + logutil.Logger(context.Background()).Debug("ssl check failure for subject", zap.String("user", priv.User), zap.String("host", priv.Host), + zap.String("require", priv.Priv.X509Subject), zap.String("given", given)) + } + } + hasCert = true + } + checkResult := hasCert && matchIssuer != fail && matchSubject != fail + if !checkResult && !hasCert { + logutil.Logger(context.Background()).Debug("ssl check failure, require issuer/subject but no verified cert", + zap.String("user", priv.User), zap.String("host", priv.Host)) + } + return checkResult + default: + panic(fmt.Sprintf("support ssl_type: %d", priv.Priv.SSLType)) + } +} + // DBIsVisible implements the Manager interface. func (p *UserPrivileges) DBIsVisible(activeRoles []*auth.RoleIdentity, db string) bool { if SkipWithGrant { diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index a4c8961d2e57d..3245f2b95d679 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -16,6 +16,9 @@ package privileges_test import ( "bytes" "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "fmt" "strings" "testing" @@ -31,6 +34,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" @@ -446,17 +450,239 @@ func (s *testPrivilegeSuite) TestSelectViewSecurity(c *C) { } func (s *testPrivilegeSuite) TestRoleAdminSecurity(c *C) { + se := newSession(c, s.store, s.dbName) + mustExec(c, se, `CREATE USER 'ar1'@'localhost';`) + mustExec(c, se, `CREATE USER 'ar2'@'localhost';`) + mustExec(c, se, `GRANT ALL ON *.* to ar1@localhost`) + defer func() { + c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + mustExec(c, se, "drop user 'ar1'@'localhost'") + mustExec(c, se, "drop user 'ar2'@'localhost'") + }() + + c.Assert(se.Auth(&auth.UserIdentity{Username: "ar1", Hostname: "localhost"}, nil, nil), IsTrue) + mustExec(c, se, `create role r_test1@localhost`) + + c.Assert(se.Auth(&auth.UserIdentity{Username: "ar2", Hostname: "localhost"}, nil, nil), IsTrue) + _, err := se.Execute(context.Background(), `create role r_test2@localhost`) + c.Assert(terror.ErrorEqual(err, core.ErrSpecificAccessDenied), IsTrue) +} + +func (s *testPrivilegeSuite) TestCheckCertBasedAuth(c *C) { se := newSession(c, s.store, s.dbName) mustExec(c, se, `CREATE USER 'r1'@'localhost';`) - mustExec(c, se, `CREATE USER 'r2'@'localhost';`) - mustExec(c, se, `GRANT ALL ON *.* to r1@localhost`) + mustExec(c, se, `CREATE USER 'r2'@'localhost' require none;`) + mustExec(c, se, `CREATE USER 'r3'@'localhost' require ssl;`) + mustExec(c, se, `CREATE USER 'r4'@'localhost' require x509;`) + mustExec(c, se, `CREATE USER 'r5'@'localhost' require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin' + subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1' cipher 'TLS_AES_128_GCM_SHA256'`) + mustExec(c, se, `CREATE USER 'r6'@'localhost' require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin' + subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, `CREATE USER 'r7_issuer_only'@'localhost' require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin'`) + mustExec(c, se, `CREATE USER 'r8_subject_only'@'localhost' require subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, `CREATE USER 'r9_subject_disorder'@'localhost' require subject '/ST=Beijing/C=ZH/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, `CREATE USER 'r10_issuer_disorder'@'localhost' require issuer '/ST=California/C=US/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin'`) + mustExec(c, se, `CREATE USER 'r11_cipher_only'@'localhost' require cipher 'TLS_AES_256_GCM_SHA384'`) + mustExec(c, se, `CREATE USER 'r12_old_tidb_user'@'localhost'`) + mustExec(c, se, "DELETE FROM mysql.global_priv WHERE `user` = 'r12_old_tidb_user' and `host` = 'localhost'") + mustExec(c, se, `CREATE USER 'r13_broken_user'@'localhost'require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin' + subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, "UPDATE mysql.global_priv set priv = 'abc' where `user` = 'r13_broken_user' and `host` = 'localhost'") + mustExec(c, se, "flush privileges") + + defer func() { + c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + mustExec(c, se, "drop user 'r1'@'localhost'") + mustExec(c, se, "drop user 'r2'@'localhost'") + mustExec(c, se, "drop user 'r3'@'localhost'") + mustExec(c, se, "drop user 'r4'@'localhost'") + mustExec(c, se, "drop user 'r5'@'localhost'") + mustExec(c, se, "drop user 'r6'@'localhost'") + mustExec(c, se, "drop user 'r7_issuer_only'@'localhost'") + mustExec(c, se, "drop user 'r8_subject_only'@'localhost'") + mustExec(c, se, "drop user 'r9_subject_disorder'@'localhost'") + mustExec(c, se, "drop user 'r10_issuer_disorder'@'localhost'") + mustExec(c, se, "drop user 'r11_cipher_only'@'localhost'") + mustExec(c, se, "drop user 'r12_old_tidb_user'@'localhost'") + mustExec(c, se, "drop user 'r13_broken_user'@'localhost'") + }() + + // test without ssl or ca + c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + // test use ssl without ca + se.GetSessionVars().TLSConnectionState = &tls.ConnectionState{VerifiedChains: nil} c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) - mustExec(c, se, `create role r_test1@localhost`) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + // test use ssl with signed but info wrong ca. + se.GetSessionVars().TLSConnectionState = &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{{{}}}} + c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) - _, err := se.Execute(context.Background(), `create role r_test2@localhost`) - c.Assert(terror.ErrorEqual(err, core.ErrSpecificAccessDenied), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + + // test a all pass case + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsTrue) + + // test require but give nothing + se.GetSessionVars().TLSConnectionState = nil + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + + // test mismatch cipher + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_256_GCM_SHA384) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r6", Hostname: "localhost"}, nil, nil), IsTrue) // not require cipher + c.Assert(se.Auth(&auth.UserIdentity{Username: "r11_cipher_only", Hostname: "localhost"}, nil, nil), IsTrue) + + // test only subject or only issuer + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "AZ"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Shijingshang"), + util.MockPkixAttribute(util.Organization, "CAPPing.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester2"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r7_issuer_only", Hostname: "localhost"}, nil, nil), IsTrue) + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "AU"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin2"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r8_subject_only", Hostname: "localhost"}, nil, nil), IsTrue) + + // test disorder issuer or subject + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{}, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r9_subject_disorder", Hostname: "localhost"}, nil, nil), IsFalse) + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{}, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r10_issuer_disorder", Hostname: "localhost"}, nil, nil), IsFalse) + + // test old data and broken data + c.Assert(se.Auth(&auth.UserIdentity{Username: "r12_old_tidb_user", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r13_broken_user", Hostname: "localhost"}, nil, nil), IsFalse) + +} + +func connectionState(issuer, subject pkix.Name, cipher uint16) *tls.ConnectionState { + return &tls.ConnectionState{ + VerifiedChains: [][]*x509.Certificate{{{Issuer: issuer, Subject: subject}}}, + CipherSuite: cipher, + } } func (s *testPrivilegeSuite) TestCheckAuthenticate(c *C) { diff --git a/server/tidb_test.go b/server/tidb_test.go index 8797a3be27c39..00d690589fbe6 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -37,6 +37,7 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" ) @@ -220,7 +221,7 @@ func generateCert(sn int, commonName string, parentCert *x509.Certificate, paren template := x509.Certificate{ SerialNumber: big.NewInt(int64(sn)), - Subject: pkix.Name{CommonName: commonName}, + Subject: pkix.Name{CommonName: commonName, Names: []pkix.AttributeTypeAndValue{util.MockPkixAttribute(util.CommonName, commonName)}}, NotBefore: notBefore, NotAfter: notAfter, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, diff --git a/session/bootstrap.go b/session/bootstrap.go index cfe367d8024c3..24090cbc4dd85 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -75,6 +75,13 @@ const ( Account_locked ENUM('N','Y') NOT NULL DEFAULT 'N', Shutdown_priv ENUM('N','Y') NOT NULL DEFAULT 'N', PRIMARY KEY (Host, User));` + // CreateGlobalPrivTable is the SQL statement creates Global scope privilege table in system db. + CreateGlobalPrivTable = "CREATE TABLE if not exists mysql.global_priv (" + + "Host char(60) NOT NULL DEFAULT ''," + + "User char(80) NOT NULL DEFAULT ''," + + "Priv longtext NOT NULL DEFAULT ''," + + "PRIMARY KEY (Host, User)" + + ")" // CreateDBPrivTable is the SQL statement creates DB scope privilege table in system db. CreateDBPrivTable = `CREATE TABLE if not exists mysql.db ( Host CHAR(60), @@ -353,6 +360,7 @@ const ( version35 = 35 version36 = 36 version37 = 37 + version38 = 38 ) func checkBootstrapped(s Session) (bool, error) { @@ -561,6 +569,10 @@ func upgrade(s Session) { upgradeToVer37(s) } + if ver < version38 { + upgradeToVer38(s) + } + updateBootstrapVer(s) _, err = s.Execute(context.Background(), "COMMIT") @@ -894,6 +906,14 @@ func upgradeToVer37(s Session) { mustExecute(s, sql) } +func upgradeToVer38(s Session) { + var err error + _, err = s.Execute(context.Background(), CreateGlobalPrivTable) + if err != nil { + logutil.Logger(context.Background()).Fatal("upgradeToVer38 error", zap.Error(err)) + } +} + // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. @@ -923,6 +943,7 @@ func doDDLWorks(s Session) { // Create user table. mustExecute(s, CreateUserTable) // Create privilege tables. + mustExecute(s, CreateGlobalPrivTable) mustExecute(s, CreateDBPrivTable) mustExecute(s, CreateTablePrivTable) mustExecute(s, CreateColumnPrivTable) diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index d46cfa5b1fb1d..3adf346d7764c 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -61,6 +61,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) mustExecSQL(c, se, "USE test;") // Check privilege tables. + mustExecSQL(c, se, "SELECT * from mysql.global_priv;") mustExecSQL(c, se, "SELECT * from mysql.db;") mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") @@ -165,6 +166,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { mustExecSQL(c, se, "USE test;") // Check privilege tables. + mustExecSQL(c, se, "SELECT * from mysql.global_priv;") mustExecSQL(c, se, "SELECT * from mysql.db;") mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") diff --git a/session/session.go b/session/session.go index e95519a53e6c7..8684d14b287b6 100644 --- a/session/session.go +++ b/session/session.go @@ -1364,7 +1364,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by // Check IP or localhost. var success bool - user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) + user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) if success { s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) @@ -1377,7 +1377,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by // Check Hostname. for _, addr := range getHostByIP(user.Hostname) { - u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt) + u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState) if success { s.sessionVars.User = &auth.UserIdentity{ Username: user.Username, @@ -1625,7 +1625,7 @@ func createSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er const ( notBootstrapped = 0 - currentBootstrapVersion = version37 + currentBootstrapVersion = version38 ) func getStoreBootstrapVersion(store kv.Storage) int64 { diff --git a/sessionctx/variable/statusvar.go b/sessionctx/variable/statusvar.go index 31250195dd746..7d2fc4a11efa6 100644 --- a/sessionctx/variable/statusvar.go +++ b/sessionctx/variable/statusvar.go @@ -17,6 +17,8 @@ import ( "bytes" "crypto/tls" "sync" + + "github.com/pingcap/tidb/util" ) var statisticsList []Statistics @@ -92,6 +94,9 @@ var tlsCiphers = []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, } var tlsSupportedCiphers string @@ -102,32 +107,7 @@ var tlsVersionString = map[uint16]string{ tls.VersionTLS10: "TLSv1", tls.VersionTLS11: "TLSv1.1", tls.VersionTLS12: "TLSv1.2", -} - -// Taken from https://testssl.sh/openssl-rfc.mapping.html . -var tlsCipherString = map[uint16]string{ - tls.TLS_RSA_WITH_RC4_128_SHA: "RC4-SHA", - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "DES-CBC3-SHA", - tls.TLS_RSA_WITH_AES_128_CBC_SHA: "AES128-SHA", - tls.TLS_RSA_WITH_AES_256_CBC_SHA: "AES256-SHA", - tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "AES128-SHA256", - tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "AES128-GCM-SHA256", - tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "AES256-GCM-SHA384", - tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "ECDHE-ECDSA-RC4-SHA", - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "ECDHE-ECDSA-AES128-SHA", - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", - tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "ECDHE-RSA-RC4-SHA", - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "ECDHE-RSA-DES-CBC3-SHA", - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "ECDHE-RSA-AES128-SHA", - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "ECDHE-ECDSA-AES128-SHA256", - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "ECDHE-RSA-AES128-SHA256", - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "ECDHE-RSA-CHACHA20-POLY1305", - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "ECDHE-ECDSA-CHACHA20-POLY1305", + tls.VersionTLS13: "TLSv1.3", } var defaultStatus = map[string]*StatusVal{ @@ -153,7 +133,7 @@ func (s defaultStatusStat) Stats(vars *SessionVars) (map[string]interface{}, err // `vars` may be nil in unit tests. if vars != nil && vars.TLSConnectionState != nil { - statusVars["Ssl_cipher"] = tlsCipherString[vars.TLSConnectionState.CipherSuite] + statusVars["Ssl_cipher"] = util.TLSCipher2String(vars.TLSConnectionState.CipherSuite) statusVars["Ssl_cipher_list"] = tlsSupportedCiphers // tls.VerifyClientCertIfGiven == SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE statusVars["Ssl_verify_mode"] = 0x01 | 0x04 @@ -166,7 +146,7 @@ func (s defaultStatusStat) Stats(vars *SessionVars) (map[string]interface{}, err func init() { var ciphersBuffer bytes.Buffer for _, v := range tlsCiphers { - ciphersBuffer.WriteString(tlsCipherString[v]) + ciphersBuffer.WriteString(util.TLSCipher2String(v)) ciphersBuffer.WriteString(":") } tlsSupportedCiphers = ciphersBuffer.String() diff --git a/util/misc.go b/util/misc.go index 15ecad2b7e296..5161385d0e9c9 100644 --- a/util/misc.go +++ b/util/misc.go @@ -15,12 +15,18 @@ package util import ( "context" + "crypto/tls" + "crypto/x509/pkix" + "fmt" "runtime" + "strconv" "strings" "time" "github.com/pingcap/errors" "github.com/pingcap/parser" + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" @@ -131,3 +137,166 @@ func SyntaxWarn(err error) error { } return parser.ErrParse.GenWithStackByArgs(syntaxErrorPrefix, err.Error()) } + +var ( + // InformationSchemaName is the `INFORMATION_SCHEMA` database name. + InformationSchemaName = model.CIStr{O: "INFORMATION_SCHEMA", L: "information_schema"} + // PerformanceSchemaName is the `PERFORMANCE_SCHEMA` database name. + PerformanceSchemaName = model.CIStr{O: "PERFORMANCE_SCHEMA", L: "performance_schema"} + // MetricSchemaName is the `METRIC_SCHEMA` database name. + MetricSchemaName = model.CIStr{O: "METRIC_SCHEMA", L: "metric_schema"} +) + +// IsMemOrSysDB uses to check whether dbLowerName is memory database or system database. +func IsMemOrSysDB(dbLowerName string) bool { + switch dbLowerName { + case InformationSchemaName.L, PerformanceSchemaName.L, mysql.SystemDB, MetricSchemaName.L: + return true + } + return false +} + +// X509NameOnline prints pkix.Name into old X509_NAME_oneline format. +// https://www.openssl.org/docs/manmaster/man3/X509_NAME_oneline.html +func X509NameOnline(n pkix.Name) string { + s := make([]string, 0, len(n.Names)) + for _, name := range n.Names { + oid := name.Type.String() + // unlike MySQL, TiDB only support check pkixAttributeTypeNames fields + if n, exist := pkixAttributeTypeNames[oid]; exist { + s = append(s, n+"="+fmt.Sprint(name.Value)) + } + } + if len(s) == 0 { + return "" + } + return "/" + strings.Join(s, "/") +} + +const ( + // Country is type name for country. + Country = "C" + // Organization is type name for organization. + Organization = "O" + // OrganizationalUnit is type name for organizational unit. + OrganizationalUnit = "OU" + // Locality is type name for locality. + Locality = "L" + // Email is type name for email. + Email = "emailAddress" + // CommonName is type name for common name. + CommonName = "CN" + // Province is type name for province or state. + Province = "ST" +) + +// see go/src/crypto/x509/pkix/pkix.go:attributeTypeNames +var pkixAttributeTypeNames = map[string]string{ + "2.5.4.6": Country, + "2.5.4.10": Organization, + "2.5.4.11": OrganizationalUnit, + "2.5.4.3": CommonName, + "2.5.4.5": "SERIALNUMBER", + "2.5.4.7": Locality, + "2.5.4.8": Province, + "2.5.4.9": "STREET", + "2.5.4.17": "POSTALCODE", + "1.2.840.113549.1.9.1": Email, +} + +var pkixTypeNameAttributes = make(map[string]string) + +// MockPkixAttribute generates mock AttributeTypeAndValue. +// only used for test. +func MockPkixAttribute(name, value string) pkix.AttributeTypeAndValue { + n, exists := pkixTypeNameAttributes[name] + if !exists { + panic(fmt.Sprintf("unsupport mock type: %s", name)) + } + var vs []int + for _, v := range strings.Split(n, ".") { + i, err := strconv.Atoi(v) + if err != nil { + panic(err) + } + vs = append(vs, i) + } + return pkix.AttributeTypeAndValue{ + Type: vs, + Value: value, + } +} + +// CheckSupportX509NameOneline parses and validate input str is X509_NAME_oneline format +// and precheck check-item is supported by TiDB +// https://www.openssl.org/docs/manmaster/man3/X509_NAME_oneline.html +func CheckSupportX509NameOneline(oneline string) (err error) { + entries := strings.Split(oneline, `/`) + for _, entry := range entries { + if len(entry) == 0 { + continue + } + kvs := strings.Split(entry, "=") + if len(kvs) != 2 { + err = errors.Errorf("invalid X509_NAME input: %s", oneline) + return + } + k := kvs[0] + if _, support := pkixTypeNameAttributes[k]; !support { + err = errors.Errorf("Unsupport check '%s' in current version TiDB", k) + return + } + } + return +} + +var tlsCipherString = map[uint16]string{ + tls.TLS_RSA_WITH_RC4_128_SHA: "RC4-SHA", + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "DES-CBC3-SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA: "AES128-SHA", + tls.TLS_RSA_WITH_AES_256_CBC_SHA: "AES256-SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "AES128-SHA256", + tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "AES128-GCM-SHA256", + tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "AES256-GCM-SHA384", + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "ECDHE-ECDSA-RC4-SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "ECDHE-ECDSA-AES128-SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "ECDHE-RSA-RC4-SHA", + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "ECDHE-RSA-DES-CBC3-SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "ECDHE-RSA-AES128-SHA", + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "ECDHE-ECDSA-AES128-SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "ECDHE-RSA-AES128-SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "ECDHE-RSA-CHACHA20-POLY1305", + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "ECDHE-ECDSA-CHACHA20-POLY1305", + // TLS 1.3 cipher suites, compatible with mysql using '_'. + tls.TLS_AES_128_GCM_SHA256: "TLS_AES_128_GCM_SHA256", + tls.TLS_AES_256_GCM_SHA384: "TLS_AES_256_GCM_SHA384", + tls.TLS_CHACHA20_POLY1305_SHA256: "TLS_CHACHA20_POLY1305_SHA256", +} + +// SupportCipher maintains cipher supported by TiDB. +var SupportCipher = make(map[string]struct{}, len(tlsCipherString)) + +// TLSCipher2String convert tls num to string. +// Taken from https://testssl.sh/openssl-rfc.mapping.html . +func TLSCipher2String(n uint16) string { + s, ok := tlsCipherString[n] + if !ok { + return "" + } + return s +} + +func init() { + for _, value := range tlsCipherString { + SupportCipher[value] = struct{}{} + } + for key, value := range pkixAttributeTypeNames { + pkixTypeNameAttributes[value] = key + } +} diff --git a/util/misc_test.go b/util/misc_test.go index 7c365a98fdb5a..8a6112b6454a6 100644 --- a/util/misc_test.go +++ b/util/misc_test.go @@ -15,6 +15,7 @@ package util import ( "bytes" + "crypto/x509/pkix" "time" . "github.com/pingcap/check" @@ -121,6 +122,23 @@ func (s *testMiscSuite) TestCompatibleParseGCTime(c *C) { } } +func (s *testMiscSuite) TestX509NameParseMatch(c *C) { + check := pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + MockPkixAttribute(Country, "SE"), + MockPkixAttribute(Province, "Stockholm2"), + MockPkixAttribute(Locality, "Stockholm"), + MockPkixAttribute(Organization, "MySQL demo client certificate"), + MockPkixAttribute(OrganizationalUnit, "testUnit"), + MockPkixAttribute(CommonName, "client"), + MockPkixAttribute(Email, "client@example.com"), + }, + } + c.Assert(X509NameOnline(check), Equals, "/C=SE/ST=Stockholm2/L=Stockholm/O=MySQL demo client certificate/OU=testUnit/CN=client/emailAddress=client@example.com") + check = pkix.Name{} + c.Assert(X509NameOnline(check), Equals, "") +} + func (s *testMiscSuite) TestBasicFunc(c *C) { // Test for GetStack. b := GetStack()