From fabc7633499a8d388c53ef976742c20387d9c12b Mon Sep 17 00:00:00 2001 From: lysu Date: Mon, 16 Dec 2019 09:35:33 +0800 Subject: [PATCH 1/4] support cert auth --- executor/builder.go | 1 + executor/grant.go | 146 +++++++++++++-- executor/grant_test.go | 60 ++++++ executor/show.go | 23 ++- executor/simple.go | 65 ++++++- go.sum | 4 - privilege/privilege.go | 3 +- privilege/privileges/cache.go | 123 ++++++++++++ privilege/privileges/cache_test.go | 38 ++++ privilege/privileges/privileges.go | 111 ++++++++++- privilege/privileges/privileges_test.go | 236 +++++++++++++++++++++++- server/tidb_test.go | 3 +- session/bootstrap.go | 21 +++ session/bootstrap_test.go | 2 + session/session.go | 6 +- sessionctx/variable/statusvar.go | 36 +--- util/misc.go | 149 +++++++++++++++ util/misc_test.go | 18 ++ 18 files changed, 980 insertions(+), 65 deletions(-) diff --git a/executor/builder.go b/executor/builder.go index 5a32c99dd649a..b88e060689963 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -805,6 +805,7 @@ func (b *executorBuilder) buildGrant(grant *ast.GrantStmt) Executor { Level: grant.Level, Users: grant.Users, WithGrant: grant.WithGrant, + TLSOptions: grant.TLSOptions, is: b.is, } return e diff --git a/executor/grant.go b/executor/grant.go index 66157d591b9b3..f1ce523d63729 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -15,6 +15,7 @@ package executor import ( "context" + "encoding/json" "fmt" "strings" @@ -24,8 +25,10 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/sqlexec" ) @@ -46,6 +49,7 @@ type GrantExec struct { ObjectType ast.ObjectTypeType Level *ast.GrantLevel Users []*ast.UserSpec + TLSOptions []*ast.TLSOption is infoschema.InfoSchema WithGrant bool @@ -86,9 +90,14 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } // If there is no privilege entry in corresponding table, insert a new one. - // DB scope: mysql.DB - // Table scope: mysql.Tables_priv - // Column scope: mysql.Columns_priv + // Global scope: mysql.global_priv + // DB scope: mysql.DB + // Table scope: mysql.Tables_priv + // Column scope: mysql.Columns_priv + err = checkAndInitGlobalPriv(e.ctx, user.User.Username, user.User.Hostname) + if err != nil { + return err + } switch e.Level.Level { case ast.GrantLevelDB: err := checkAndInitDBPriv(e.ctx, dbName, e.is, user.User.Username, user.User.Hostname) @@ -113,7 +122,11 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() } - + // Grant global priv to user. + err = e.grantGlobalPriv(user) + if err != nil { + return err + } // Grant each priv to the user. for _, priv := range privs { if len(priv.Cols) > 0 { @@ -124,7 +137,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { return err } } - err := e.grantPriv(priv, user) + err := e.grantLevelPriv(priv, user) if err != nil { return err } @@ -134,6 +147,20 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { return nil } +// checkAndInitGlobalPriv checks if global scope privilege entry exists in mysql.global_priv. +// If unexists, insert a new one. +func checkAndInitGlobalPriv(ctx sessionctx.Context, user string, host string) error { + ok, err := globalPrivEntryExists(ctx, user, host) + if err != nil { + return err + } + if ok { + return nil + } + // Entry does not exist for user-host-db. Insert a new entry. + return initGlobalPrivEntry(ctx, user, host) +} + // checkAndInitDBPriv checks if DB scope privilege entry exists in mysql.DB. // If unexists, insert a new one. func checkAndInitDBPriv(ctx sessionctx.Context, dbName string, is infoschema.InfoSchema, user string, host string) error { @@ -190,6 +217,13 @@ func (e *GrantExec) checkAndInitColumnPriv(user string, host string, cols []*ast return nil } +// initGlobalPrivEntry inserts a new row into mysql.DB with empty privilege. +func initGlobalPrivEntry(ctx sessionctx.Context, user string, host string) error { + sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, PRIV) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.GlobalPrivTable, host, user, "{}") + _, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + return err +} + // initDBPrivEntry inserts a new row into mysql.DB with empty privilege. func initDBPrivEntry(ctx sessionctx.Context, user string, host string, db string) error { sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, DB) VALUES ('%s', '%s', '%s')`, mysql.SystemDB, mysql.DBTable, host, user, db) @@ -211,25 +245,93 @@ func initColumnPrivEntry(ctx sessionctx.Context, user string, host string, db st return err } -// grantPriv grants priv to user in s.Level scope. -func (e *GrantExec) grantPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantGlobalPriv grants priv to user in global scope. +func (e *GrantExec) grantGlobalPriv(user *ast.UserSpec) error { + if len(e.TLSOptions) == 0 { + return nil + } + priv, err := tlsOption2GlobalPriv(e.TLSOptions) + if err != nil { + return errors.Trace(err) + } + sql := fmt.Sprintf(`UPDATE %s.%s SET PRIV = '%s' WHERE User='%s' AND Host='%s'`, mysql.SystemDB, mysql.GlobalPrivTable, priv, user.User.Username, user.User.Hostname) + _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + return err +} + +var emptyGP = privileges.GlobalPrivValue{SSLType: privileges.SslTypeNotSpecified} + +func tlsOption2GlobalPriv(tlsOptions []*ast.TLSOption) (priv []byte, err error) { + if len(tlsOptions) == 0 { + priv = []byte("{}") + return + } + gp := privileges.GlobalPrivValue{SSLType: privileges.SslTypeNotSpecified} + for _, tlsOpt := range tlsOptions { + switch tlsOpt.Type { + case ast.TslNone: + gp.SSLType = privileges.SslTypeNone + case ast.Ssl: + gp.SSLType = privileges.SslTypeAny + case ast.X509: + gp.SSLType = privileges.SslTypeX509 + case ast.Cipher: + gp.SSLType = privileges.SslTypeSpecified + if len(tlsOpt.Value) > 0 { + if _, ok := util.SupportCipher[tlsOpt.Value]; !ok { + err = errors.Errorf("Unsupported cipher suit: %s", tlsOpt.Value) + return + } + gp.SSLCipher = tlsOpt.Value + } + case ast.Issuer: + err = util.CheckSupportX509NameOneline(tlsOpt.Value) + if err != nil { + return + } + gp.SSLType = privileges.SslTypeSpecified + gp.X509Issuer = tlsOpt.Value + case ast.Subject: + err = util.CheckSupportX509NameOneline(tlsOpt.Value) + if err != nil { + return + } + gp.SSLType = privileges.SslTypeSpecified + gp.X509Subject = tlsOpt.Value + default: + err = errors.Errorf("Unknown ssl type: %#v", tlsOpt.Type) + return + } + } + if gp == emptyGP { + return + } + priv, err = json.Marshal(&gp) + if err != nil { + return + } + return +} + +// grantLevelPriv grants priv to user in s.Level scope. +func (e *GrantExec) grantLevelPriv(priv *ast.PrivElem, user *ast.UserSpec) error { switch e.Level.Level { case ast.GrantLevelGlobal: - return e.grantGlobalPriv(priv, user) + return e.grantGlobalLevel(priv, user) case ast.GrantLevelDB: - return e.grantDBPriv(priv, user) + return e.grantDBLevel(priv, user) case ast.GrantLevelTable: if len(priv.Cols) == 0 { - return e.grantTablePriv(priv, user) + return e.grantTableLevel(priv, user) } - return e.grantColumnPriv(priv, user) + return e.grantColumnLevel(priv, user) default: return errors.Errorf("Unknown grant level: %#v", e.Level) } } -// grantGlobalPriv manipulates mysql.user table. -func (e *GrantExec) grantGlobalPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantGlobalLevel manipulates mysql.user table. +func (e *GrantExec) grantGlobalLevel(priv *ast.PrivElem, user *ast.UserSpec) error { if priv.Priv == 0 { return nil } @@ -242,8 +344,8 @@ func (e *GrantExec) grantGlobalPriv(priv *ast.PrivElem, user *ast.UserSpec) erro return err } -// grantDBPriv manipulates mysql.db table. -func (e *GrantExec) grantDBPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantDBLevel manipulates mysql.db table. +func (e *GrantExec) grantDBLevel(priv *ast.PrivElem, user *ast.UserSpec) error { dbName := e.Level.DBName if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB @@ -257,8 +359,8 @@ func (e *GrantExec) grantDBPriv(priv *ast.PrivElem, user *ast.UserSpec) error { return err } -// grantTablePriv manipulates mysql.tables_priv table. -func (e *GrantExec) grantTablePriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantTableLevel manipulates mysql.tables_priv table. +func (e *GrantExec) grantTableLevel(priv *ast.PrivElem, user *ast.UserSpec) error { dbName := e.Level.DBName if len(dbName) == 0 { dbName = e.ctx.GetSessionVars().CurrentDB @@ -273,8 +375,8 @@ func (e *GrantExec) grantTablePriv(priv *ast.PrivElem, user *ast.UserSpec) error return err } -// grantColumnPriv manipulates mysql.tables_priv table. -func (e *GrantExec) grantColumnPriv(priv *ast.PrivElem, user *ast.UserSpec) error { +// grantColumnLevel manipulates mysql.tables_priv table. +func (e *GrantExec) grantColumnLevel(priv *ast.PrivElem, user *ast.UserSpec) error { dbName, tbl, err := getTargetSchemaAndTable(e.ctx, e.Level.DBName, e.Level.TableName, e.is) if err != nil { return err @@ -473,6 +575,12 @@ func recordExists(ctx sessionctx.Context, sql string) (bool, error) { return len(rows) > 0, nil } +// globalPrivEntryExists checks if there is an entry with key user-host in mysql.global_priv. +func globalPrivEntryExists(ctx sessionctx.Context, name string, host string) (bool, error) { + sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, mysql.SystemDB, mysql.GlobalPrivTable, name, host) + return recordExists(ctx, sql) +} + // dbUserExists checks if there is an entry with key user-host-db in mysql.DB. func dbUserExists(ctx sessionctx.Context, name string, host string, db string) (bool, error) { sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s' AND DB='%s';`, mysql.SystemDB, mysql.DBTable, name, host, db) diff --git a/executor/grant_test.go b/executor/grant_test.go index f652ec37a1448..d61ab8f12f4dc 100644 --- a/executor/grant_test.go +++ b/executor/grant_test.go @@ -237,3 +237,63 @@ func (s *testSuite3) TestGrantUnderANSIQuotes(c *C) { tk.MustExec(`REVOKE ALL PRIVILEGES ON video_ulimit.* FROM web@'%';`) tk.MustExec(`DROP USER IF EXISTS 'web'@'%'`) } + +func (s *testSuite3) TestMaintainRequire(c *C) { + tk := testkit.NewTestKit(c, s.store) + + // test create with require + tk.MustExec(`CREATE USER 'ssl_auser'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'`) + tk.MustExec(`CREATE USER 'ssl_buser'@'%' require subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'`) + tk.MustExec(`CREATE USER 'ssl_cuser'@'%' require cipher 'AES128-GCM-SHA256'`) + tk.MustExec(`CREATE USER 'ssl_duser'@'%'`) + tk.MustExec(`CREATE USER 'ssl_euser'@'%' require none`) + tk.MustExec(`CREATE USER 'ssl_fuser'@'%' require ssl`) + tk.MustExec(`CREATE USER 'ssl_guser'@'%' require x509`) + tk.MustQuery("select * from mysql.global_priv where `user` like 'ssl_%'").Check(testkit.Rows( + "% ssl_auser {\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\",\"x509_issuer\":\"/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}", + "% ssl_buser {\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}", + "% ssl_cuser {\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\"}", + "% ssl_duser {}", + "% ssl_euser {}", + "% ssl_fuser {\"ssl_type\":1}", + "% ssl_guser {\"ssl_type\":2}", + )) + + // test grant with require + tk.MustExec("CREATE USER 'u1'@'%'") + tk.MustExec("GRANT ALL ON *.* TO 'u1'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' and subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH'") // add new require. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{\"ssl_type\":3,\"x509_issuer\":\"/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}")) + tk.MustExec("GRANT ALL ON *.* TO 'u1'@'%' require cipher 'AES128-GCM-SHA256'") // modify always overwrite. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\"}")) + tk.MustExec("GRANT select ON *.* TO 'u1'@'%'") // modify without require should not modify old require. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\"}")) + tk.MustExec("GRANT ALL ON *.* TO 'u1'@'%' require none") // use require none to clean up require. + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u1'").Check(testkit.Rows("{}")) + + // test alter with require + tk.MustExec("CREATE USER 'u2'@'%'") + tk.MustExec("alter user 'u2'@'%' require ssl") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{\"ssl_type\":1}")) + tk.MustExec("alter user 'u2'@'%' require x509") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{\"ssl_type\":2}")) + tk.MustExec("alter user 'u2'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{\"ssl_type\":3,\"ssl_cipher\":\"AES128-GCM-SHA256\",\"x509_issuer\":\"/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US\",\"x509_subject\":\"/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH\"}")) + tk.MustExec("alter user 'u2'@'%' require none") + tk.MustQuery("select priv from mysql.global_priv where `Host` = '%' and `User` = 'u2'").Check(testkit.Rows("{}")) + + // test show create user + tk.MustExec(`CREATE USER 'u3'@'%' require issuer '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' subject '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' cipher 'AES128-GCM-SHA256'`) + tk.MustQuery("show create user 'u3'").Check(testkit.Rows("CREATE USER 'u3'@'%' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE CIPHER 'AES128-GCM-SHA256' ISSUER '/CN=TiDB admin/OU=TiDB/O=PingCAP/L=San Francisco/ST=California/C=US' SUBJECT '/CN=tester1/OU=TiDB/O=PingCAP.Inc/L=Haidian/ST=Beijing/C=ZH' PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) + + // check issuer/subject/cipher value + _, err := tk.Exec(`CREATE USER 'u4'@'%' require issuer 'CN=TiDB,OU=PingCAP'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u5'@'%' require subject '/CN=TiDB\OU=PingCAP'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u6'@'%' require subject '/CN=TiDB\NC=PingCAP'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u7'@'%' require cipher 'AES128-GCM-SHA1'`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u8'@'%' require subject '/CN'`) + c.Assert(err, NotNil) +} diff --git a/executor/show.go b/executor/show.go index 79189c416ab06..5cbffc6a543c6 100644 --- a/executor/show.go +++ b/executor/show.go @@ -16,6 +16,7 @@ package executor import ( "bytes" "context" + gjson "encoding/json" "fmt" "sort" "strconv" @@ -43,6 +44,7 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/privilege/privileges" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -52,6 +54,7 @@ import ( "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/format" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/sqlexec" ) @@ -1009,8 +1012,24 @@ func (e *ShowExec) fetchShowCreateUser() error { return ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", fmt.Sprintf("'%s'@'%s'", e.User.Username, e.User.Hostname)) } - showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", - e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname)) + sql = fmt.Sprintf(`SELECT PRIV FROM %s.%s WHERE User='%s' AND Host='%s'`, + mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) + rows, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + if err != nil { + return errors.Trace(err) + } + require := "NONE" + if len(rows) == 1 { + privData := rows[0].GetString(0) + var privValue privileges.GlobalPrivValue + err = gjson.Unmarshal(hack.Slice(privData), &privValue) + if err != nil { + return errors.Trace(err) + } + require = privValue.RequireStr() + } + showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE %s PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", + e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname), require) e.appendRow([]interface{}{showStr}) return nil } diff --git a/executor/simple.go b/executor/simple.go index 0137cd750ec9f..e4f4f6510887a 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -38,6 +38,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" @@ -659,7 +660,13 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm } } + privData, err := tlsOption2GlobalPriv(s.TLSOptions) + if err != nil { + return err + } + users := make([]string, 0, len(s.Specs)) + privs := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) if err1 != nil { @@ -686,6 +693,11 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm user = fmt.Sprintf(`('%s', '%s', '%s', 'Y')`, spec.User.Hostname, spec.User.Username, pwd) } users = append(users, user) + + if len(privData) != 0 { + priv := fmt.Sprintf(`('%s', '%s', '%s')`, spec.User.Hostname, spec.User.Username, hack.String(privData)) + privs = append(privs, priv) + } } if len(users) == 0 { return nil @@ -695,10 +707,37 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if s.IsCreateRole { sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) } - _, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + + restrictedCtx, err := e.getSysSession() if err != nil { return err } + defer e.releaseSysSession(restrictedCtx) + sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) + + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return errors.Trace(err) + } + _, err = sqlExecutor.Execute(context.Background(), sql) + if err != nil { + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + if len(privs) != 0 { + sql = fmt.Sprintf("INSERT IGNORE INTO %s.%s (Host, User, Priv) VALUES %s", mysql.SystemDB, mysql.GlobalPrivTable, strings.Join(privs, ", ")) + _, err = sqlExecutor.Execute(context.Background(), sql) + if err != nil { + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + } + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return errors.Trace(err) + } domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return err } @@ -716,6 +755,11 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { s.Specs = []*ast.UserSpec{spec} } + privData, err := tlsOption2GlobalPriv(s.TLSOptions) + if err != nil { + return err + } + failedUsers := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err := userExists(e.ctx, spec.User.Username, spec.User.Hostname) @@ -741,6 +785,15 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { if err != nil { failedUsers = append(failedUsers, spec.User.String()) } + + if len(privData) > 0 { + sql = fmt.Sprintf("INSERT INTO %s.%s (Host, User, Priv) VALUES ('%s','%s','%s') ON DUPLICATE KEY UPDATE Priv = values(Priv)", + mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, hack.String(privData)) + _, _, err = e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(sql) + if err != nil { + failedUsers = append(failedUsers, spec.User.String()) + } + } } if len(failedUsers) > 0 { // Commit the transaction even if we returns error @@ -878,6 +931,16 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { break } + // delete privileges from mysql.global_priv + sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.GlobalPrivTable, user.Hostname, user.Username) + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + failedUsers = append(failedUsers, user.String()) + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { + return err + } + continue + } + // delete privileges from mysql.db sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) if _, err = sqlExecutor.Execute(context.Background(), sql); err != nil { diff --git a/go.sum b/go.sum index a7dc19ee4e218..a0beafea5a8a7 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,6 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20171208011716-f6d7a1f6fbf3/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa h1:OaNxuTZr7kxeODyLWsRMC+OD03aFUH+mW6r2d+MWa5Y= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= @@ -219,7 +218,6 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_model v0.0.0-20170216185247-6f3806018612/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20171117100541-99fa1f4be8e5/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -234,7 +232,6 @@ github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 h1:FUL3b97ZY github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc= github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shirou/gopsutil v2.19.10+incompatible h1:lA4Pi29JEVIQIgATSeftHSY0rMGI9CLrl2ZvDLiahto= github.com/shirou/gopsutil v2.19.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= @@ -383,7 +380,6 @@ google.golang.org/grpc v0.0.0-20180607172857-7a6a684ca69e/go.mod h1:yo6s7OP7yaDg google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.23.1 h1:q4XQuHFC6I28BKZpo6IYyb3mNO+l7lSOxRuYTCiDfXk= google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= diff --git a/privilege/privilege.go b/privilege/privilege.go index 9362239618441..7e1655e151a23 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -14,6 +14,7 @@ package privilege import ( + "crypto/tls" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" @@ -45,7 +46,7 @@ type Manager interface { RequestVerificationWithUser(db, table, column string, priv mysql.PrivilegeType, user *auth.UserIdentity) bool // ConnectionVerification verifies user privilege for connection. - ConnectionVerification(user, host string, auth, salt []byte) (string, string, bool) + ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index f107e2255f661..a8f419ec5200e 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -15,6 +15,7 @@ package privileges import ( "context" + "encoding/json" "fmt" "sort" "strings" @@ -29,6 +30,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tidb/util/stringutil" @@ -62,6 +64,71 @@ type UserRecord struct { patTypes []byte } +type globalPrivRecord struct { + Host string + User string + Priv GlobalPrivValue + Broken bool + + // patChars is compiled from Host, cached for pattern match performance. + patChars []byte + patTypes []byte +} + +// SSLType is enum value for GlobalPrivValue.SSLType. +// the value is compatible with MySQL storage json value. +type SSLType int + +const ( + // SslTypeNotSpecified indicates . + SslTypeNotSpecified SSLType = iota - 1 + // SslTypeNone indicates not require use ssl. + SslTypeNone + // SslTypeAny indicates require use ssl but not validate cert. + SslTypeAny + // SslTypeX509 indicates require use ssl and validate cert. + SslTypeX509 + // SslTypeSpecified indicates require use ssl and validate cert's subject or issuer. + SslTypeSpecified +) + +// GlobalPrivValue is store json format for priv column in mysql.global_priv. +type GlobalPrivValue struct { + SSLType SSLType `json:"ssl_type,omitempty"` + SSLCipher string `json:"ssl_cipher,omitempty"` + X509Issuer string `json:"x509_issuer,omitempty"` + X509Subject string `json:"x509_subject,omitempty"` +} + +// RequireStr returns describe string after `REQUIRE` clause. +func (g *GlobalPrivValue) RequireStr() string { + require := "NONE" + switch g.SSLType { + case SslTypeAny: + require = "SSL" + case SslTypeX509: + require = "X509" + case SslTypeSpecified: + var s []string + if len(g.SSLCipher) > 0 { + s = append(s, "CIPHER") + s = append(s, "'"+g.SSLCipher+"'") + } + if len(g.X509Issuer) > 0 { + s = append(s, "ISSUER") + s = append(s, "'"+g.X509Issuer+"'") + } + if len(g.X509Subject) > 0 { + s = append(s, "SUBJECT") + s = append(s, "'"+g.X509Subject+"'") + } + if len(s) > 0 { + require = strings.Join(s, " ") + } + } + return require +} + type dbRecord struct { Host string DB string @@ -151,6 +218,7 @@ type MySQLPrivilege struct { // non-full privileges (i.e. user.db entries). User []UserRecord UserMap map[string][]UserRecord // Accelerate User searching + Global []globalPrivRecord DB []dbRecord DBMap map[string][]dbRecord // Accelerate DB searching TablesPriv []tablesPrivRecord @@ -204,6 +272,11 @@ func (p *MySQLPrivilege) LoadAll(ctx sessionctx.Context) error { return errLoadPrivilege.FastGen("mysql.user") } + err = p.LoadGlobalPrivTable(ctx) + if err != nil { + return errors.Trace(err) + } + err = p.LoadDBTable(ctx) if err != nil { if !noSuchTable(err) { @@ -382,6 +455,11 @@ func (p MySQLPrivilege) SortUserTable() { sort.Sort(sortedUserRecord(p.User)) } +// LoadGlobalPrivTable loads the mysql.global_priv table from database. +func (p *MySQLPrivilege) LoadGlobalPrivTable(ctx sessionctx.Context) error { + return p.loadTable(ctx, "select HIGH_PRIORITY Host,User,Priv from mysql.global_priv", p.decodeGlobalPrivTableRow) +} + // LoadDBTable loads the mysql.db table from database. func (p *MySQLPrivilege) LoadDBTable(ctx sessionctx.Context) error { err := p.loadTable(ctx, "select HIGH_PRIORITY Host,DB,User,Select_priv,Insert_priv,Update_priv,Delete_priv,Create_priv,Drop_priv,Grant_priv,Index_priv,Alter_priv,Execute_priv,Create_view_priv,Show_view_priv from mysql.db order by host, db, user;", p.decodeDBTableRow) @@ -492,6 +570,37 @@ func (p *MySQLPrivilege) decodeUserTableRow(row chunk.Row, fs []*ast.ResultField return nil } +func (p *MySQLPrivilege) decodeGlobalPrivTableRow(row chunk.Row, fs []*ast.ResultField) error { + var value globalPrivRecord + for i, f := range fs { + switch { + case f.ColumnAsName.L == "host": + value.Host = row.GetString(i) + value.patChars, value.patTypes = stringutil.CompilePattern(value.Host, '\\') + case f.ColumnAsName.L == "user": + value.User = row.GetString(i) + case f.ColumnAsName.L == "priv": + privData := row.GetString(i) + if len(privData) > 0 { + var privValue GlobalPrivValue + err := json.Unmarshal(hack.Slice(privData), &privValue) + if err != nil { + logutil.BgLogger().Error("one userglobal priv data is broken, forbidden login until data be fixed", + zap.String("user", value.User), zap.String("host", value.Host)) + value.Broken = true + } else { + value.Priv.SSLType = privValue.SSLType + value.Priv.SSLCipher = privValue.SSLCipher + value.Priv.X509Issuer = privValue.X509Issuer + value.Priv.X509Subject = privValue.X509Subject + } + } + } + } + p.Global = append(p.Global, value) + return nil +} + func (p *MySQLPrivilege) decodeDBTableRow(row chunk.Row, fs []*ast.ResultField) error { var value dbRecord for i, f := range fs { @@ -631,6 +740,10 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { return ret } +func (record *globalPrivRecord) match(user, host string) bool { + return record.User == user && patternMatch(host, record.patChars, record.patTypes) +} + func (record *UserRecord) match(user, host string) bool { return record.User == user && patternMatch(host, record.patChars, record.patTypes) } @@ -674,6 +787,16 @@ func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { return nil } +func (p *MySQLPrivilege) matchGlobalPriv(user, host string) *globalPrivRecord { + for i := 0; i < len(p.Global); i++ { + record := &p.Global[i] + if record.match(user, host) { + return record + } + } + return nil +} + func (p *MySQLPrivilege) matchUser(user, host string) *UserRecord { records, exists := p.UserMap[user] if exists { diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index 8e837ad624e41..c8b7d86348c1e 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -77,6 +77,25 @@ func (s *testCacheSuite) TestLoadUserTable(c *C) { c.Assert(user[3].Privileges, Equals, mysql.CreateUserPriv|mysql.IndexPriv|mysql.ExecutePriv|mysql.CreateViewPriv|mysql.ShowViewPriv|mysql.ShowDBPriv|mysql.SuperPriv|mysql.TriggerPriv) } +func (s *testCacheSuite) TestLoadGlobalPrivTable(c *C) { + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + defer se.Close() + mustExec(c, se, "use mysql;") + mustExec(c, se, "truncate table global_priv") + + mustExec(c, se, `INSERT INTO mysql.global_priv VALUES ("%", "tu", "{\"access\":0,\"plugin\":\"mysql_native_password\",\"ssl_type\":3, \"ssl_cipher\":\"cipher\",\"x509_subject\":\"\C=ZH1\", \"x509_issuer\":\"\C=ZH2\", \"password_last_changed\":1}")`) + + var p privileges.MySQLPrivilege + err = p.LoadGlobalPrivTable(se) + c.Assert(err, IsNil) + c.Assert(p.Global[0].Host, Equals, `%`) + c.Assert(p.Global[0].User, Equals, `tu`) + c.Assert(p.Global[0].Priv.SSLType, Equals, privileges.SslTypeSpecified) + c.Assert(p.Global[0].Priv.X509Issuer, Equals, "C=ZH2") + c.Assert(p.Global[0].Priv.X509Subject, Equals, "C=ZH1") +} + func (s *testCacheSuite) TestLoadDBTable(c *C) { se, err := session.CreateSession4Test(s.store) c.Assert(err, IsNil) @@ -402,6 +421,25 @@ func (s *testCacheSuite) TestSortUserTable(c *C) { checkUserRecord(p.User, result, c) } +func (s *testCacheSuite) TestGlobalPrivValueRequireStr(c *C) { + var ( + none = privileges.GlobalPrivValue{SSLType: privileges.SslTypeNone} + tls = privileges.GlobalPrivValue{SSLType: privileges.SslTypeAny} + x509 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeX509} + spec = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified, SSLCipher: "c1", X509Subject: "s1", X509Issuer: "i1"} + spec2 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified, X509Subject: "s1", X509Issuer: "i1"} + spec3 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified, X509Issuer: "i1"} + spec4 = privileges.GlobalPrivValue{SSLType: privileges.SslTypeSpecified} + ) + c.Assert(none.RequireStr(), Equals, "NONE") + c.Assert(tls.RequireStr(), Equals, "SSL") + c.Assert(x509.RequireStr(), Equals, "X509") + c.Assert(spec.RequireStr(), Equals, "CIPHER 'c1' ISSUER 'i1' SUBJECT 's1'") + c.Assert(spec2.RequireStr(), Equals, "ISSUER 'i1' SUBJECT 's1'") + c.Assert(spec3.RequireStr(), Equals, "ISSUER 'i1'") + c.Assert(spec4.RequireStr(), Equals, "NONE") +} + func checkUserRecord(x, y []privileges.UserRecord, c *C) { c.Assert(len(x), Equals, len(y)) for i := 0; i < len(x); i++ { diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index aaaaff623afc4..144e686548ebb 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -14,6 +14,8 @@ package privileges import ( + "crypto/tls" + "fmt" "strings" "github.com/pingcap/parser/auth" @@ -21,6 +23,7 @@ import ( "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -101,7 +104,7 @@ func (p *UserPrivileges) GetEncodedPassword(user, host string) string { } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte) (u string, h string, success bool) { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { if SkipWithGrant { p.user = user p.host = host @@ -120,6 +123,16 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio u = record.User h = record.Host + globalPriv := mysqlPriv.matchGlobalPriv(user, host) + if globalPriv != nil { + if !p.checkSSL(globalPriv, tlsState) { + logutil.BgLogger().Error("global priv check ssl fail", + zap.String("user", user), zap.String("host", host)) + success = false + return + } + } + // Login a locked account is not allowed. locked := record.AccountLocked if locked { @@ -163,6 +176,102 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } +type checkResult int + +const ( + notCheck checkResult = iota + pass + fail +) + +func (p *UserPrivileges) checkSSL(priv *globalPrivRecord, tlsState *tls.ConnectionState) bool { + if priv.Broken { + logutil.BgLogger().Debug("ssl check failure, due to broken global_priv record", + zap.String("user", priv.User), zap.String("host", priv.Host)) + return false + } + switch priv.Priv.SSLType { + case SslTypeNotSpecified, SslTypeNone: + return true + case SslTypeAny: + r := tlsState != nil + if !r { + logutil.BgLogger().Debug("ssl check failure, require ssl but not use ssl", + zap.String("user", priv.User), zap.String("host", priv.Host)) + } + return r + case SslTypeX509: + if tlsState == nil { + logutil.BgLogger().Debug("ssl check failure, require x509 but not use ssl", + zap.String("user", priv.User), zap.String("host", priv.Host)) + return false + } + hasCert := false + for _, chain := range tlsState.VerifiedChains { + if len(chain) > 0 { + hasCert = true + break + } + } + if !hasCert { + logutil.BgLogger().Debug("ssl check failure, require x509 but no verified cert", + zap.String("user", priv.User), zap.String("host", priv.Host)) + } + return hasCert + case SslTypeSpecified: + if tlsState == nil { + logutil.BgLogger().Debug("ssl check failure, require subject/issuer/cipher but not use ssl", + zap.String("user", priv.User), zap.String("host", priv.Host)) + return false + } + if len(priv.Priv.SSLCipher) > 0 && priv.Priv.SSLCipher != util.TLSCipher2String(tlsState.CipherSuite) { + logutil.BgLogger().Debug("ssl check failure for cipher", zap.String("user", priv.User), zap.String("host", priv.Host), + zap.String("require", priv.Priv.SSLCipher), zap.String("given", util.TLSCipher2String(tlsState.CipherSuite))) + return false + } + var ( + hasCert = false + matchIssuer checkResult + matchSubject checkResult + ) + for _, chain := range tlsState.VerifiedChains { + if len(chain) == 0 { + continue + } + cert := chain[0] + if len(priv.Priv.X509Issuer) > 0 { + given := util.X509NameOnline(cert.Issuer) + if priv.Priv.X509Issuer == given { + matchIssuer = pass + } else if matchIssuer == notCheck { + matchIssuer = fail + logutil.BgLogger().Debug("ssl check failure for issuer", zap.String("user", priv.User), zap.String("host", priv.Host), + zap.String("require", priv.Priv.X509Issuer), zap.String("given", given)) + } + } + if len(priv.Priv.X509Subject) > 0 { + given := util.X509NameOnline(cert.Subject) + if priv.Priv.X509Subject == given { + matchSubject = pass + } else if matchSubject == notCheck { + matchSubject = fail + logutil.BgLogger().Debug("ssl check failure for subject", zap.String("user", priv.User), zap.String("host", priv.Host), + zap.String("require", priv.Priv.X509Subject), zap.String("given", given)) + } + } + hasCert = true + } + checkResult := hasCert && matchIssuer != fail && matchSubject != fail + if !checkResult && !hasCert { + logutil.BgLogger().Debug("ssl check failure, require issuer/subject but no verified cert", + zap.String("user", priv.User), zap.String("host", priv.Host)) + } + return checkResult + default: + panic(fmt.Sprintf("support ssl_type: %d", priv.Priv.SSLType)) + } +} + // DBIsVisible implements the Manager interface. func (p *UserPrivileges) DBIsVisible(activeRoles []*auth.RoleIdentity, db string) bool { if SkipWithGrant { diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 50c22a231d7e2..fe119f3b965f8 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -16,6 +16,9 @@ package privileges_test import ( "bytes" "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "fmt" "strings" "testing" @@ -31,6 +34,7 @@ import ( "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" @@ -446,17 +450,239 @@ func (s *testPrivilegeSuite) TestSelectViewSecurity(c *C) { } func (s *testPrivilegeSuite) TestRoleAdminSecurity(c *C) { + se := newSession(c, s.store, s.dbName) + mustExec(c, se, `CREATE USER 'ar1'@'localhost';`) + mustExec(c, se, `CREATE USER 'ar2'@'localhost';`) + mustExec(c, se, `GRANT ALL ON *.* to ar1@localhost`) + defer func() { + c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + mustExec(c, se, "drop user 'ar1'@'localhost'") + mustExec(c, se, "drop user 'ar2'@'localhost'") + }() + + c.Assert(se.Auth(&auth.UserIdentity{Username: "ar1", Hostname: "localhost"}, nil, nil), IsTrue) + mustExec(c, se, `create role r_test1@localhost`) + + c.Assert(se.Auth(&auth.UserIdentity{Username: "ar2", Hostname: "localhost"}, nil, nil), IsTrue) + _, err := se.Execute(context.Background(), `create role r_test2@localhost`) + c.Assert(terror.ErrorEqual(err, core.ErrSpecificAccessDenied), IsTrue) +} + +func (s *testPrivilegeSuite) TestCheckCertBasedAuth(c *C) { se := newSession(c, s.store, s.dbName) mustExec(c, se, `CREATE USER 'r1'@'localhost';`) - mustExec(c, se, `CREATE USER 'r2'@'localhost';`) - mustExec(c, se, `GRANT ALL ON *.* to r1@localhost`) + mustExec(c, se, `CREATE USER 'r2'@'localhost' require none;`) + mustExec(c, se, `CREATE USER 'r3'@'localhost' require ssl;`) + mustExec(c, se, `CREATE USER 'r4'@'localhost' require x509;`) + mustExec(c, se, `CREATE USER 'r5'@'localhost' require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin' + subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1' cipher 'TLS_AES_128_GCM_SHA256'`) + mustExec(c, se, `CREATE USER 'r6'@'localhost' require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin' + subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, `CREATE USER 'r7_issuer_only'@'localhost' require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin'`) + mustExec(c, se, `CREATE USER 'r8_subject_only'@'localhost' require subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, `CREATE USER 'r9_subject_disorder'@'localhost' require subject '/ST=Beijing/C=ZH/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, `CREATE USER 'r10_issuer_disorder'@'localhost' require issuer '/ST=California/C=US/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin'`) + mustExec(c, se, `CREATE USER 'r11_cipher_only'@'localhost' require cipher 'TLS_AES_256_GCM_SHA384'`) + mustExec(c, se, `CREATE USER 'r12_old_tidb_user'@'localhost'`) + mustExec(c, se, "DELETE FROM mysql.global_priv WHERE `user` = 'r12_old_tidb_user' and `host` = 'localhost'") + mustExec(c, se, `CREATE USER 'r13_broken_user'@'localhost'require issuer '/C=US/ST=California/L=San Francisco/O=PingCAP/OU=TiDB/CN=TiDB admin' + subject '/C=ZH/ST=Beijing/L=Haidian/O=PingCAP.Inc/OU=TiDB/CN=tester1'`) + mustExec(c, se, "UPDATE mysql.global_priv set priv = 'abc' where `user` = 'r13_broken_user' and `host` = 'localhost'") + mustExec(c, se, "flush privileges") + + defer func() { + c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + mustExec(c, se, "drop user 'r1'@'localhost'") + mustExec(c, se, "drop user 'r2'@'localhost'") + mustExec(c, se, "drop user 'r3'@'localhost'") + mustExec(c, se, "drop user 'r4'@'localhost'") + mustExec(c, se, "drop user 'r5'@'localhost'") + mustExec(c, se, "drop user 'r6'@'localhost'") + mustExec(c, se, "drop user 'r7_issuer_only'@'localhost'") + mustExec(c, se, "drop user 'r8_subject_only'@'localhost'") + mustExec(c, se, "drop user 'r9_subject_disorder'@'localhost'") + mustExec(c, se, "drop user 'r10_issuer_disorder'@'localhost'") + mustExec(c, se, "drop user 'r11_cipher_only'@'localhost'") + mustExec(c, se, "drop user 'r12_old_tidb_user'@'localhost'") + mustExec(c, se, "drop user 'r13_broken_user'@'localhost'") + }() + + // test without ssl or ca + c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + // test use ssl without ca + se.GetSessionVars().TLSConnectionState = &tls.ConnectionState{VerifiedChains: nil} c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) - mustExec(c, se, `create role r_test1@localhost`) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + // test use ssl with signed but info wrong ca. + se.GetSessionVars().TLSConnectionState = &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{{{}}}} + c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) - _, err := se.Execute(context.Background(), `create role r_test2@localhost`) - c.Assert(terror.ErrorEqual(err, core.ErrSpecificAccessDenied), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + + // test a all pass case + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r1", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r2", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r3", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r4", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsTrue) + + // test require but give nothing + se.GetSessionVars().TLSConnectionState = nil + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + + // test mismatch cipher + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_256_GCM_SHA384) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r5", Hostname: "localhost"}, nil, nil), IsFalse) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r6", Hostname: "localhost"}, nil, nil), IsTrue) // not require cipher + c.Assert(se.Auth(&auth.UserIdentity{Username: "r11_cipher_only", Hostname: "localhost"}, nil, nil), IsTrue) + + // test only subject or only issuer + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "AZ"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Shijingshang"), + util.MockPkixAttribute(util.Organization, "CAPPing.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester2"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r7_issuer_only", Hostname: "localhost"}, nil, nil), IsTrue) + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "AU"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin2"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r8_subject_only", Hostname: "localhost"}, nil, nil), IsTrue) + + // test disorder issuer or subject + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{}, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "ZH"), + util.MockPkixAttribute(util.Province, "Beijing"), + util.MockPkixAttribute(util.Locality, "Haidian"), + util.MockPkixAttribute(util.Organization, "PingCAP.Inc"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "tester1"), + }, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r9_subject_disorder", Hostname: "localhost"}, nil, nil), IsFalse) + se.GetSessionVars().TLSConnectionState = connectionState( + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + util.MockPkixAttribute(util.Country, "US"), + util.MockPkixAttribute(util.Province, "California"), + util.MockPkixAttribute(util.Locality, "San Francisco"), + util.MockPkixAttribute(util.Organization, "PingCAP"), + util.MockPkixAttribute(util.OrganizationalUnit, "TiDB"), + util.MockPkixAttribute(util.CommonName, "TiDB admin"), + }, + }, + pkix.Name{ + Names: []pkix.AttributeTypeAndValue{}, + }, + tls.TLS_AES_128_GCM_SHA256) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r10_issuer_disorder", Hostname: "localhost"}, nil, nil), IsFalse) + + // test old data and broken data + c.Assert(se.Auth(&auth.UserIdentity{Username: "r12_old_tidb_user", Hostname: "localhost"}, nil, nil), IsTrue) + c.Assert(se.Auth(&auth.UserIdentity{Username: "r13_broken_user", Hostname: "localhost"}, nil, nil), IsFalse) + +} + +func connectionState(issuer, subject pkix.Name, cipher uint16) *tls.ConnectionState { + return &tls.ConnectionState{ + VerifiedChains: [][]*x509.Certificate{{{Issuer: issuer, Subject: subject}}}, + CipherSuite: cipher, + } } func (s *testPrivilegeSuite) TestCheckAuthenticate(c *C) { diff --git a/server/tidb_test.go b/server/tidb_test.go index 4f787990754e8..9b1a71111c67e 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -22,6 +22,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "github.com/pingcap/tidb/util" "io/ioutil" "math/big" "os" @@ -219,7 +220,7 @@ func generateCert(sn int, commonName string, parentCert *x509.Certificate, paren template := x509.Certificate{ SerialNumber: big.NewInt(int64(sn)), - Subject: pkix.Name{CommonName: commonName}, + Subject: pkix.Name{CommonName: commonName, Names: []pkix.AttributeTypeAndValue{util.MockPkixAttribute(util.CommonName, commonName)}}, NotBefore: notBefore, NotAfter: notAfter, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, diff --git a/session/bootstrap.go b/session/bootstrap.go index 59d83a088e1e1..f1364e15ab069 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -75,6 +75,13 @@ const ( Account_locked ENUM('N','Y') NOT NULL DEFAULT 'N', Shutdown_priv ENUM('N','Y') NOT NULL DEFAULT 'N', PRIMARY KEY (Host, User));` + // CreateGlobalPrivTable is the SQL statement creates Global scope privilege table in system db. + CreateGlobalPrivTable = "CREATE TABLE if not exists mysql.global_priv (" + + "Host char(60) NOT NULL DEFAULT ''," + + "User char(80) NOT NULL DEFAULT ''," + + "Priv longtext NOT NULL DEFAULT ''," + + "PRIMARY KEY (Host, User)" + + ")" // CreateDBPrivTable is the SQL statement creates DB scope privilege table in system db. CreateDBPrivTable = `CREATE TABLE if not exists mysql.db ( Host CHAR(60), @@ -353,6 +360,7 @@ const ( version35 = 35 version36 = 36 version37 = 37 + version38 = 38 ) func checkBootstrapped(s Session) (bool, error) { @@ -561,6 +569,10 @@ func upgrade(s Session) { upgradeToVer37(s) } + if ver < version38 { + upgradeToVer38(s) + } + updateBootstrapVer(s) _, err = s.Execute(context.Background(), "COMMIT") @@ -890,6 +902,14 @@ func upgradeToVer37(s Session) { mustExecute(s, sql) } +func upgradeToVer38(s Session) { + var err error + _, err = s.Execute(context.Background(), CreateGlobalPrivTable) + if err != nil { + logutil.BgLogger().Fatal("upgradeToVer38 error", zap.Error(err)) + } +} + // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. @@ -919,6 +939,7 @@ func doDDLWorks(s Session) { // Create user table. mustExecute(s, CreateUserTable) // Create privilege tables. + mustExecute(s, CreateGlobalPrivTable) mustExecute(s, CreateDBPrivTable) mustExecute(s, CreateTablePrivTable) mustExecute(s, CreateColumnPrivTable) diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index be91eae2e8d6c..5b027deb2e7ea 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -61,6 +61,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "anyhost"}, []byte(""), []byte("")), IsTrue) mustExecSQL(c, se, "USE test;") // Check privilege tables. + mustExecSQL(c, se, "SELECT * from mysql.global_priv;") mustExecSQL(c, se, "SELECT * from mysql.db;") mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") @@ -165,6 +166,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { mustExecSQL(c, se, "USE test;") // Check privilege tables. + mustExecSQL(c, se, "SELECT * from mysql.global_priv;") mustExecSQL(c, se, "SELECT * from mysql.db;") mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") diff --git a/session/session.go b/session/session.go index 44e12de2a0736..f19408ea0c223 100644 --- a/session/session.go +++ b/session/session.go @@ -1472,7 +1472,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by // Check IP or localhost. var success bool - user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt) + user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) if success { s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) @@ -1485,7 +1485,7 @@ func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []by // Check Hostname. for _, addr := range getHostByIP(user.Hostname) { - u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt) + u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState) if success { s.sessionVars.User = &auth.UserIdentity{ Username: user.Username, @@ -1744,7 +1744,7 @@ func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er const ( notBootstrapped = 0 - currentBootstrapVersion = version37 + currentBootstrapVersion = version38 ) func getStoreBootstrapVersion(store kv.Storage) int64 { diff --git a/sessionctx/variable/statusvar.go b/sessionctx/variable/statusvar.go index 31250195dd746..7d2fc4a11efa6 100644 --- a/sessionctx/variable/statusvar.go +++ b/sessionctx/variable/statusvar.go @@ -17,6 +17,8 @@ import ( "bytes" "crypto/tls" "sync" + + "github.com/pingcap/tidb/util" ) var statisticsList []Statistics @@ -92,6 +94,9 @@ var tlsCiphers = []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, } var tlsSupportedCiphers string @@ -102,32 +107,7 @@ var tlsVersionString = map[uint16]string{ tls.VersionTLS10: "TLSv1", tls.VersionTLS11: "TLSv1.1", tls.VersionTLS12: "TLSv1.2", -} - -// Taken from https://testssl.sh/openssl-rfc.mapping.html . -var tlsCipherString = map[uint16]string{ - tls.TLS_RSA_WITH_RC4_128_SHA: "RC4-SHA", - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "DES-CBC3-SHA", - tls.TLS_RSA_WITH_AES_128_CBC_SHA: "AES128-SHA", - tls.TLS_RSA_WITH_AES_256_CBC_SHA: "AES256-SHA", - tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "AES128-SHA256", - tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "AES128-GCM-SHA256", - tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "AES256-GCM-SHA384", - tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "ECDHE-ECDSA-RC4-SHA", - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "ECDHE-ECDSA-AES128-SHA", - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", - tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "ECDHE-RSA-RC4-SHA", - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "ECDHE-RSA-DES-CBC3-SHA", - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "ECDHE-RSA-AES128-SHA", - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "ECDHE-ECDSA-AES128-SHA256", - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "ECDHE-RSA-AES128-SHA256", - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "ECDHE-RSA-CHACHA20-POLY1305", - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "ECDHE-ECDSA-CHACHA20-POLY1305", + tls.VersionTLS13: "TLSv1.3", } var defaultStatus = map[string]*StatusVal{ @@ -153,7 +133,7 @@ func (s defaultStatusStat) Stats(vars *SessionVars) (map[string]interface{}, err // `vars` may be nil in unit tests. if vars != nil && vars.TLSConnectionState != nil { - statusVars["Ssl_cipher"] = tlsCipherString[vars.TLSConnectionState.CipherSuite] + statusVars["Ssl_cipher"] = util.TLSCipher2String(vars.TLSConnectionState.CipherSuite) statusVars["Ssl_cipher_list"] = tlsSupportedCiphers // tls.VerifyClientCertIfGiven == SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE statusVars["Ssl_verify_mode"] = 0x01 | 0x04 @@ -166,7 +146,7 @@ func (s defaultStatusStat) Stats(vars *SessionVars) (map[string]interface{}, err func init() { var ciphersBuffer bytes.Buffer for _, v := range tlsCiphers { - ciphersBuffer.WriteString(tlsCipherString[v]) + ciphersBuffer.WriteString(util.TLSCipher2String(v)) ciphersBuffer.WriteString(":") } tlsSupportedCiphers = ciphersBuffer.String() diff --git a/util/misc.go b/util/misc.go index c0d0b86ddbef8..5ba10c07295b4 100644 --- a/util/misc.go +++ b/util/misc.go @@ -14,7 +14,11 @@ package util import ( + "crypto/tls" + "crypto/x509/pkix" + "fmt" "runtime" + "strconv" "strings" "time" @@ -150,3 +154,148 @@ func IsMemOrSysDB(dbLowerName string) bool { } return false } + +// X509NameOnline prints pkix.Name into old X509_NAME_oneline format. +// https://www.openssl.org/docs/manmaster/man3/X509_NAME_oneline.html +func X509NameOnline(n pkix.Name) string { + s := make([]string, 0, len(n.Names)) + for _, name := range n.Names { + oid := name.Type.String() + // unlike MySQL, TiDB only support check pkixAttributeTypeNames fields + if n, exist := pkixAttributeTypeNames[oid]; exist { + s = append(s, n+"="+fmt.Sprint(name.Value)) + } + } + if len(s) == 0 { + return "" + } + return "/" + strings.Join(s, "/") +} + +const ( + // Country is type name for country. + Country = "C" + // Organization is type name for organization. + Organization = "O" + // OrganizationalUnit is type name for organizational unit. + OrganizationalUnit = "OU" + // Locality is type name for locality. + Locality = "L" + // Email is type name for email. + Email = "emailAddress" + // CommonName is type name for common name. + CommonName = "CN" + // Province is type name for province or state. + Province = "ST" +) + +// see go/src/crypto/x509/pkix/pkix.go:attributeTypeNames +var pkixAttributeTypeNames = map[string]string{ + "2.5.4.6": Country, + "2.5.4.10": Organization, + "2.5.4.11": OrganizationalUnit, + "2.5.4.3": CommonName, + "2.5.4.5": "SERIALNUMBER", + "2.5.4.7": Locality, + "2.5.4.8": Province, + "2.5.4.9": "STREET", + "2.5.4.17": "POSTALCODE", + "1.2.840.113549.1.9.1": Email, +} + +var pkixTypeNameAttributes = make(map[string]string) + +// MockPkixAttribute generates mock AttributeTypeAndValue. +// only used for test. +func MockPkixAttribute(name, value string) pkix.AttributeTypeAndValue { + n, exists := pkixTypeNameAttributes[name] + if !exists { + panic(fmt.Sprintf("unsupport mock type: %s", name)) + } + var vs []int + for _, v := range strings.Split(n, ".") { + i, err := strconv.Atoi(v) + if err != nil { + panic(err) + } + vs = append(vs, i) + } + return pkix.AttributeTypeAndValue{ + Type: vs, + Value: value, + } +} + +// CheckSupportX509NameOneline parses and validate input str is X509_NAME_oneline format +// and precheck check-item is supported by TiDB +// https://www.openssl.org/docs/manmaster/man3/X509_NAME_oneline.html +func CheckSupportX509NameOneline(oneline string) (err error) { + entries := strings.Split(oneline, `/`) + for _, entry := range entries { + if len(entry) == 0 { + continue + } + kvs := strings.Split(entry, "=") + if len(kvs) != 2 { + err = errors.Errorf("invalid X509_NAME input: %s", oneline) + return + } + k := kvs[0] + if _, support := pkixTypeNameAttributes[k]; !support { + err = errors.Errorf("Unsupport check '%s' in current version TiDB", k) + return + } + } + return +} + +var tlsCipherString = map[uint16]string{ + tls.TLS_RSA_WITH_RC4_128_SHA: "RC4-SHA", + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "DES-CBC3-SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA: "AES128-SHA", + tls.TLS_RSA_WITH_AES_256_CBC_SHA: "AES256-SHA", + tls.TLS_RSA_WITH_AES_128_CBC_SHA256: "AES128-SHA256", + tls.TLS_RSA_WITH_AES_128_GCM_SHA256: "AES128-GCM-SHA256", + tls.TLS_RSA_WITH_AES_256_GCM_SHA384: "AES256-GCM-SHA384", + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA: "ECDHE-ECDSA-RC4-SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: "ECDHE-ECDSA-AES128-SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA: "ECDHE-RSA-RC4-SHA", + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA: "ECDHE-RSA-DES-CBC3-SHA", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: "ECDHE-RSA-AES128-SHA", + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: "ECDHE-ECDSA-AES128-SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: "ECDHE-RSA-AES128-SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: "ECDHE-RSA-CHACHA20-POLY1305", + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: "ECDHE-ECDSA-CHACHA20-POLY1305", + // TLS 1.3 cipher suites, compatible with mysql using '_'. + tls.TLS_AES_128_GCM_SHA256: "TLS_AES_128_GCM_SHA256", + tls.TLS_AES_256_GCM_SHA384: "TLS_AES_256_GCM_SHA384", + tls.TLS_CHACHA20_POLY1305_SHA256: "TLS_CHACHA20_POLY1305_SHA256", +} + +// SupportCipher maintains cipher supported by TiDB. +var SupportCipher = make(map[string]struct{}, len(tlsCipherString)) + +// TLSCipher2String convert tls num to string. +// Taken from https://testssl.sh/openssl-rfc.mapping.html . +func TLSCipher2String(n uint16) string { + s, ok := tlsCipherString[n] + if !ok { + return "" + } + return s +} + +func init() { + for _, value := range tlsCipherString { + SupportCipher[value] = struct{}{} + } + for key, value := range pkixAttributeTypeNames { + pkixTypeNameAttributes[value] = key + } +} diff --git a/util/misc_test.go b/util/misc_test.go index dba9ccd4eabf1..f7ebcd1c82478 100644 --- a/util/misc_test.go +++ b/util/misc_test.go @@ -15,6 +15,7 @@ package util import ( "bytes" + "crypto/x509/pkix" "time" . "github.com/pingcap/check" @@ -121,6 +122,23 @@ func (s *testMiscSuite) TestCompatibleParseGCTime(c *C) { } } +func (s *testMiscSuite) TestX509NameParseMatch(c *C) { + check := pkix.Name{ + Names: []pkix.AttributeTypeAndValue{ + MockPkixAttribute(Country, "SE"), + MockPkixAttribute(Province, "Stockholm2"), + MockPkixAttribute(Locality, "Stockholm"), + MockPkixAttribute(Organization, "MySQL demo client certificate"), + MockPkixAttribute(OrganizationalUnit, "testUnit"), + MockPkixAttribute(CommonName, "client"), + MockPkixAttribute(Email, "client@example.com"), + }, + } + c.Assert(X509NameOnline(check), Equals, "/C=SE/ST=Stockholm2/L=Stockholm/O=MySQL demo client certificate/OU=testUnit/CN=client/emailAddress=client@example.com") + check = pkix.Name{} + c.Assert(X509NameOnline(check), Equals, "") +} + func (s *testMiscSuite) TestBasicFunc(c *C) { // Test for GetStack. b := GetStack() From 780cca8ded0fd95d5525f6b54afc405063d17327 Mon Sep 17 00:00:00 2001 From: lysu Date: Tue, 17 Dec 2019 16:13:01 +0800 Subject: [PATCH 2/4] address comments --- executor/grant.go | 2 +- go.sum | 1 - privilege/privileges/cache.go | 2 +- server/tidb_test.go | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/executor/grant.go b/executor/grant.go index f1ce523d63729..c3b8b981b58d6 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -148,7 +148,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { } // checkAndInitGlobalPriv checks if global scope privilege entry exists in mysql.global_priv. -// If unexists, insert a new one. +// If not exists, insert a new one. func checkAndInitGlobalPriv(ctx sessionctx.Context, user string, host string) error { ok, err := globalPrivEntryExists(ctx, user, host) if err != nil { diff --git a/go.sum b/go.sum index a0beafea5a8a7..fd47a1f51bf94 100644 --- a/go.sum +++ b/go.sum @@ -180,7 +180,6 @@ github.com/pingcap/errcode v0.0.0-20180921232412-a1a7271709d9/go.mod h1:4b2X8xSq github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.5-0.20190809092503-95897b64e011 h1:58naV4XMEqm0hl9LcYo6cZoGBGiLtefMQMF/vo3XLgQ= github.com/pingcap/errors v0.11.5-0.20190809092503-95897b64e011/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= -github.com/pingcap/failpoint v0.0.0-20190512135322-30cc7431d99c h1:hvQd3aOLKLF7xvRV6DzvPkKY4QXzfVbjU1BhW0d9yL8= github.com/pingcap/failpoint v0.0.0-20190512135322-30cc7431d99c/go.mod h1:DNS3Qg7bEDhU6EXNHF+XSv/PGznQaMJ5FWvctpm6pQI= github.com/pingcap/failpoint v0.0.0-20191029060244-12f4ac2fd11d h1:F8vp38kTAckN+v8Jlc98uMBvKIzr1a+UhnLyVYn8Q5Q= github.com/pingcap/failpoint v0.0.0-20191029060244-12f4ac2fd11d/go.mod h1:DNS3Qg7bEDhU6EXNHF+XSv/PGznQaMJ5FWvctpm6pQI= diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index a8f419ec5200e..7d6fc7ae39706 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -585,7 +585,7 @@ func (p *MySQLPrivilege) decodeGlobalPrivTableRow(row chunk.Row, fs []*ast.Resul var privValue GlobalPrivValue err := json.Unmarshal(hack.Slice(privData), &privValue) if err != nil { - logutil.BgLogger().Error("one userglobal priv data is broken, forbidden login until data be fixed", + logutil.BgLogger().Error("one user global priv data is broken, forbidden login until data be fixed", zap.String("user", value.User), zap.String("host", value.Host)) value.Broken = true } else { diff --git a/server/tidb_test.go b/server/tidb_test.go index 9b1a71111c67e..e80f5de788bdb 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -22,7 +22,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "github.com/pingcap/tidb/util" "io/ioutil" "math/big" "os" @@ -38,6 +37,7 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/testkit" ) From 2ef3921ea837fc7526f2ab2f77df37250b7c7fd4 Mon Sep 17 00:00:00 2001 From: lysu Date: Wed, 18 Dec 2019 10:35:43 +0800 Subject: [PATCH 3/4] accelerate global priv search --- privilege/privileges/cache.go | 15 +++++++++++---- privilege/privileges/cache_test.go | 10 +++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 7d6fc7ae39706..210215996a086 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -218,7 +218,7 @@ type MySQLPrivilege struct { // non-full privileges (i.e. user.db entries). User []UserRecord UserMap map[string][]UserRecord // Accelerate User searching - Global []globalPrivRecord + Global map[string][]globalPrivRecord DB []dbRecord DBMap map[string][]dbRecord // Accelerate DB searching TablesPriv []tablesPrivRecord @@ -597,7 +597,10 @@ func (p *MySQLPrivilege) decodeGlobalPrivTableRow(row chunk.Row, fs []*ast.Resul } } } - p.Global = append(p.Global, value) + if p.Global == nil { + p.Global = make(map[string][]globalPrivRecord) + } + p.Global[value.User] = append(p.Global[value.User], value) return nil } @@ -788,8 +791,12 @@ func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { } func (p *MySQLPrivilege) matchGlobalPriv(user, host string) *globalPrivRecord { - for i := 0; i < len(p.Global); i++ { - record := &p.Global[i] + uGlobal, exists := p.Global[user] + if !exists { + return nil + } + for i := 0; i < len(uGlobal); i++ { + record := &uGlobal[i] if record.match(user, host) { return record } diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index c8b7d86348c1e..50b381e82eb7c 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -89,11 +89,11 @@ func (s *testCacheSuite) TestLoadGlobalPrivTable(c *C) { var p privileges.MySQLPrivilege err = p.LoadGlobalPrivTable(se) c.Assert(err, IsNil) - c.Assert(p.Global[0].Host, Equals, `%`) - c.Assert(p.Global[0].User, Equals, `tu`) - c.Assert(p.Global[0].Priv.SSLType, Equals, privileges.SslTypeSpecified) - c.Assert(p.Global[0].Priv.X509Issuer, Equals, "C=ZH2") - c.Assert(p.Global[0].Priv.X509Subject, Equals, "C=ZH1") + c.Assert(p.Global["tu"][0].Host, Equals, `%`) + c.Assert(p.Global["tu"][0].User, Equals, `tu`) + c.Assert(p.Global["tu"][0].Priv.SSLType, Equals, privileges.SslTypeSpecified) + c.Assert(p.Global["tu"][0].Priv.X509Issuer, Equals, "C=ZH2") + c.Assert(p.Global["tu"][0].Priv.X509Subject, Equals, "C=ZH1") } func (s *testCacheSuite) TestLoadDBTable(c *C) { From 6742729c89203b5b5b90ad58d4aea485f0e9cdc1 Mon Sep 17 00:00:00 2001 From: lysu Date: Fri, 20 Dec 2019 14:50:13 +0800 Subject: [PATCH 4/4] address comments --- executor/grant.go | 25 ++++++++++++++++++++++--- executor/grant_test.go | 10 ++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/executor/grant.go b/executor/grant.go index c3b8b981b58d6..c1fe9e60a7183 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -94,9 +94,11 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { // DB scope: mysql.DB // Table scope: mysql.Tables_priv // Column scope: mysql.Columns_priv - err = checkAndInitGlobalPriv(e.ctx, user.User.Username, user.User.Hostname) - if err != nil { - return err + if e.TLSOptions != nil { + err = checkAndInitGlobalPriv(e.ctx, user.User.Username, user.User.Hostname) + if err != nil { + return err + } } switch e.Level.Level { case ast.GrantLevelDB: @@ -266,6 +268,23 @@ func tlsOption2GlobalPriv(tlsOptions []*ast.TLSOption) (priv []byte, err error) priv = []byte("{}") return } + dupSet := make(map[int]struct{}) + for _, opt := range tlsOptions { + if _, dup := dupSet[opt.Type]; dup { + var typeName string + switch opt.Type { + case ast.Cipher: + typeName = "CIPHER" + case ast.Issuer: + typeName = "ISSUER" + case ast.Subject: + typeName = "SUBJECT" + } + err = errors.Errorf("Duplicate require %s clause", typeName) + return + } + dupSet[opt.Type] = struct{}{} + } gp := privileges.GlobalPrivValue{SSLType: privileges.SslTypeNotSpecified} for _, tlsOpt := range tlsOptions { switch tlsOpt.Type { diff --git a/executor/grant_test.go b/executor/grant_test.go index d61ab8f12f4dc..d5724a4cda089 100644 --- a/executor/grant_test.go +++ b/executor/grant_test.go @@ -296,4 +296,14 @@ func (s *testSuite3) TestMaintainRequire(c *C) { c.Assert(err, NotNil) _, err = tk.Exec(`CREATE USER 'u8'@'%' require subject '/CN'`) c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u9'@'%' require cipher 'TLS_AES_256_GCM_SHA384' cipher 'RC4-SHA'`) + c.Assert(err.Error(), Equals, "Duplicate require CIPHER clause") + _, err = tk.Exec(`CREATE USER 'u9'@'%' require issuer 'CN=TiDB,OU=PingCAP' issuer 'CN=TiDB,OU=PingCAP2'`) + c.Assert(err.Error(), Equals, "Duplicate require ISSUER clause") + _, err = tk.Exec(`CREATE USER 'u9'@'%' require subject '/CN=TiDB\OU=PingCAP' subject '/CN=TiDB\OU=PingCAP2'`) + c.Assert(err.Error(), Equals, "Duplicate require SUBJECT clause") + _, err = tk.Exec(`CREATE USER 'u9'@'%' require ssl ssl`) + c.Assert(err, NotNil) + _, err = tk.Exec(`CREATE USER 'u9'@'%' require x509 x509`) + c.Assert(err, NotNil) }