diff --git a/controllers/mysql_combined_test.go b/controllers/mysql_combined_test.go index 89efdd11a07..492e15c1cd7 100644 --- a/controllers/mysql_combined_test.go +++ b/controllers/mysql_combined_test.go @@ -9,10 +9,14 @@ import ( "context" "testing" + "github.com/stretchr/testify/assert" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" azurev1alpha1 "github.com/Azure/azure-service-operator/api/v1alpha1" "github.com/Azure/azure-service-operator/api/v1alpha2" + "github.com/Azure/azure-service-operator/pkg/resourcemanager/mysql" + "github.com/Azure/azure-service-operator/pkg/resourcemanager/mysql/mysqluser" ) func TestMySQLHappyPath(t *testing.T) { @@ -32,6 +36,7 @@ func TestMySQLHappyPath(t *testing.T) { RequireInstance(ctx, t, tc, mySQLServerInstance) // Create a mySQL replica + mySQLReplicaInstance := v1alpha2.NewReplicaMySQLServer(mySQLReplicaName, rgName, rgLocation, mySQLServerInstance.Status.ResourceId) mySQLReplicaInstance.Spec.StorageProfile = nil @@ -76,6 +81,8 @@ func TestMySQLHappyPath(t *testing.T) { ruleName := GenerateTestResourceNameWithRandom("mysql-fw", 10) + // This rule opens access to the public internet, but in this case + // there's literally no data in the database ruleInstance := &azurev1alpha1.MySQLFirewallRule{ ObjectMeta: metav1.ObjectMeta{ Name: ruleName, @@ -85,12 +92,15 @@ func TestMySQLHappyPath(t *testing.T) { Server: mySQLServerName, ResourceGroup: rgName, StartIPAddress: "0.0.0.0", - EndIPAddress: "0.0.0.0", + EndIPAddress: "255.255.255.255", }, } EnsureInstance(ctx, t, tc, ruleInstance) + // Create user and ensure it can be updated + RunMySQLUserHappyPath(ctx, t, mySQLServerName,mySQLDBName, rgName) + // Create VNet and VNetRules ----- RunMySqlVNetRuleHappyPath(t, mySQLServerName, rgLocation) @@ -99,3 +109,69 @@ func TestMySQLHappyPath(t *testing.T) { EnsureDelete(ctx, t, tc, mySQLServerInstance) EnsureDelete(ctx, t, tc, mySQLReplicaInstance) } + +func RunMySQLUserHappyPath(ctx context.Context, t *testing.T, mySQLServerName string, mySQLDBName string, rgName string) { + assert := assert.New(t) + + // Create a user in the DB + username := GenerateTestResourceNameWithRandom("user", 10) + user := &azurev1alpha1.MySQLUser{ + ObjectMeta: metav1.ObjectMeta{ + Name: username, + Namespace: "default", + }, + Spec: azurev1alpha1.MySQLUserSpec{ + ResourceGroup: rgName, + Server: mySQLServerName, + DbName: mySQLDBName, + Username: username, + Roles: []string{"SELECT"}, + }, + } + EnsureInstance(ctx, t, tc, user) + + // Update user + namespacedName := types.NamespacedName{Name: username, Namespace: "default"} + err := tc.k8sClient.Get(ctx, namespacedName, user) + assert.NoError(err) + + updatedRoles := []string{"UPDATE", "DELETE", "CREATE", "DROP"} + user.Spec.Roles = updatedRoles + err = tc.k8sClient.Update(ctx, user) + assert.NoError(err) + + // TODO: Ugh this is fragile, the path to the secret should probably be set on the status? + // See issue here: https://github.com/Azure/azure-service-operator/issues/1318 + secretNamespacedName := types.NamespacedName{Name: mySQLServerName, Namespace: "default"} + adminSecret, err := tc.secretClient.Get(ctx, secretNamespacedName) + adminUser := string(adminSecret["fullyQualifiedUsername"]) + adminPassword := string(adminSecret[mysqluser.MSecretPasswordKey]) + fullServerName := string(adminSecret["fullyQualifiedServerName"]) + + db, err := mysql.ConnectToSqlDB( + ctx, + mysql.MySQLDriverName, + fullServerName, + mySQLDBName, + mysql.MySQLServerPort, + adminUser, + adminPassword) + assert.NoError(err) + + assert.Eventually(func() bool { + roles, err := mysql.ExtractUserRoles(ctx, db, username, mySQLDBName) + assert.NoError(err) + + if len(roles) != len(updatedRoles) { + return false + } + + for _, role := range updatedRoles { + if _, ok := roles[role]; !ok { + return false + } + } + + return true + }, tc.timeout, tc.retry, "waiting for DB user to be updated") +} diff --git a/pkg/helpers/sqlrole.go b/pkg/helpers/sqlrole.go new file mode 100644 index 00000000000..3eb5e948aa8 --- /dev/null +++ b/pkg/helpers/sqlrole.go @@ -0,0 +1,29 @@ +package helpers + +type SQLRoleDelta struct { + AddedRoles map[string]struct{} + DeletedRoles map[string]struct{} +} + +func DiffCurrentAndExpectedSQLRoles(currentRoles map[string]struct{}, expectedRoles map[string]struct{}) SQLRoleDelta { + result := SQLRoleDelta{ + AddedRoles: make(map[string]struct{}), + DeletedRoles: make(map[string]struct{}), + } + + for role := range expectedRoles { + // If an expected role isn't in the current role set, we need to add it + if _, ok := currentRoles[role]; !ok { + result.AddedRoles[role] = struct{}{} + } + } + + for role := range currentRoles { + // If a current role isn't in the expected set, we need to remove it + if _, ok := expectedRoles[role]; !ok { + result.DeletedRoles[role] = struct{}{} + } + } + + return result +} diff --git a/pkg/helpers/sqlrole_test.go b/pkg/helpers/sqlrole_test.go new file mode 100644 index 00000000000..37f8fd0278b --- /dev/null +++ b/pkg/helpers/sqlrole_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package helpers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDiffCurrentAndExpectedSQLRoles(t *testing.T) { + assert := assert.New(t) + + cases := []struct { + name string + currentRoles map[string]struct{} + expectedRoles map[string]struct{} + expectedRoleCreates map[string]struct{} + expectedRoleDeletes map[string]struct{} + }{ + { + name: "Current and expected equal", + currentRoles: map[string]struct{}{"USAGE": {}}, + expectedRoles: map[string]struct{}{"USAGE": {}}, + expectedRoleCreates: make(map[string]struct{}), + expectedRoleDeletes: make(map[string]struct{}), + }, + { + name: "Expected has single role more than current", + currentRoles: map[string]struct{}{"USAGE": {}}, + expectedRoles: map[string]struct{}{"USAGE": {}, "SELECT": {}}, + expectedRoleCreates: map[string]struct{}{"SELECT": {}}, + expectedRoleDeletes: make(map[string]struct{}), + }, + { + name: "Expected has single role less than current", + currentRoles: map[string]struct{}{"USAGE": {}, "SELECT": {}}, + expectedRoles: map[string]struct{}{"USAGE": {}}, + expectedRoleCreates: make(map[string]struct{}), + expectedRoleDeletes: map[string]struct{}{"SELECT": {}}, + }, + { + name: "Expected has many roles less than current", + currentRoles: map[string]struct{}{"SELECT": {}, "INSERT": {}, "UPDATE": {}, "DELETE": {}, "CREATE": {}, "DROP": {}, "RELOAD": {}}, + expectedRoles: map[string]struct{}{"SELECT": {}, "INSERT": {}}, + expectedRoleCreates: make(map[string]struct{}), + expectedRoleDeletes: map[string]struct{}{"UPDATE": {}, "DELETE": {}, "CREATE": {}, "DROP": {}, "RELOAD": {}}, + }, + { + name: "Expected has many roles more than current", + currentRoles: map[string]struct{}{"SELECT": {}, "INSERT": {}}, + expectedRoles: map[string]struct{}{"SELECT": {}, "INSERT": {}, "UPDATE": {}, "DELETE": {}, "CREATE": {}, "DROP": {}, "RELOAD": {}}, + expectedRoleCreates: map[string]struct{}{"UPDATE": {}, "DELETE": {}, "CREATE": {}, "DROP": {}, "RELOAD": {}}, + expectedRoleDeletes: make(map[string]struct{}), + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + result := DiffCurrentAndExpectedSQLRoles(c.currentRoles, c.expectedRoles) + assert.Equal(c.expectedRoleCreates, result.AddedRoles) + assert.Equal(c.expectedRoleDeletes, result.DeletedRoles) + }) + } + +} diff --git a/pkg/resourcemanager/mysql/mysqlhelper.go b/pkg/resourcemanager/mysql/mysqlhelper.go index 87aafda352c..372d35b9db0 100644 --- a/pkg/resourcemanager/mysql/mysqlhelper.go +++ b/pkg/resourcemanager/mysql/mysqlhelper.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" + "github.com/pkg/errors" + "github.com/Azure/azure-service-operator/pkg/errhelp" "github.com/Azure/azure-service-operator/pkg/helpers" "github.com/Azure/azure-service-operator/pkg/resourcemanager/iam" @@ -88,31 +90,76 @@ func ConnectToSQLDBAsCurrentUser( return db, err } +// ExtractUserRoles extracts the roles the user has. The result is the set of roles the user has for the requested database +func ExtractUserRoles(ctx context.Context, db *sql.DB, user string, database string) (map[string]struct{}, error) { + if err := helpers.FindBadChars(user); err != nil { + return nil, errors.Wrapf(err, "problem found with username") + } + + // Note: This works because we only assign permissions at the DB level, not at the table, column, etc levels -- if we assigned + // permissions at more levels we would need to do something else here such as join multiple tables or + // parse SHOW GRANTS with a regex. + formattedUser := fmt.Sprintf("'%s'@'%%'", user) + rows, err := db.QueryContext( + ctx, + "SELECT PRIVILEGE_TYPE FROM INFORMATION_SCHEMA.SCHEMA_PRIVILEGES WHERE GRANTEE = ? and TABLE_SCHEMA = ?", + formattedUser, + database) + if err != nil { + return nil, errors.Wrapf(err, "listing grants for user %s", user) + } + defer rows.Close() + + result := make(map[string]struct{}) + for rows.Next() { + var row string + err := rows.Scan(&row) + if err != nil { + return nil, errors.Wrapf(err, "iterating returned rows") + } + + result[row] = struct{}{} + } + + return result, nil +} + func GrantUserRoles(ctx context.Context, user string, database string, roles []string, db *sql.DB) error { var errorStrings []string if err := helpers.FindBadChars(user); err != nil { return fmt.Errorf("problem found with username: %v", err) } + rolesMap := make(map[string]struct{}) for _, role := range roles { + rolesMap[role] = struct{}{} + } - if err := helpers.FindBadChars(role); err != nil { - return fmt.Errorf("problem found with role: %v", err) - } - - // Due to how go-mysql-driver performs parameter replacement, it always wraps - // string parameters in ''. That doesn't work for this query because some of - // our parameters are actually SQL keywords or identifiers (backticks). Admittedly - // protecting against SQL injection here is probably pointless as we're giving the caller - // permission to create users, which means there's nothing stopping them from creating - // an administrator user and then doing whatever they want without SQL injection. - // See https://github.com/go-sql-driver/mysql/blob/3b935426341bc5d229eafd936e4f4240da027ccd/connection.go#L198 - // for specifics of what go-mysql-driver supports. - tsql := fmt.Sprintf("GRANT %s ON `%s`.* TO ?", role, database) - _, err := db.ExecContext(ctx, tsql, user) - if err != nil { - errorStrings = append(errorStrings, err.Error()) - } + // Get the current roles + currentRoles, err := ExtractUserRoles(ctx, db, user, database) + if err != nil { + return errors.Wrapf(err, "couldn't get existing roles for user %s", user) + } + // Remove "USAGE" as it's special and we never grant or remove it + delete(currentRoles, "USAGE") + + rolesDiff := helpers.DiffCurrentAndExpectedSQLRoles(currentRoles, rolesMap) + + // Due to how go-mysql-driver performs parameter replacement, it always wraps + // string parameters in ''. That doesn't work for these queries because some of + // our parameters are actually SQL keywords or identifiers (requiring backticks). Admittedly + // protecting against SQL injection here is probably pointless as we're giving the caller + // permission to create users, which means there's nothing stopping them from creating + // an administrator user and then doing whatever they want without SQL injection. + // See https://github.com/go-sql-driver/mysql/blob/3b935426341bc5d229eafd936e4f4240da027ccd/connection.go#L198 + // for specifics of what go-mysql-driver supports. + err = addRoles(ctx, db, database, user, rolesDiff.AddedRoles) + if err != nil { + errorStrings = append(errorStrings, err.Error()) + } + err = deleteRoles(ctx, db, database, user, rolesDiff.DeletedRoles) + if err != nil { + errorStrings = append(errorStrings, err.Error()) } if len(errorStrings) != 0 { @@ -121,6 +168,42 @@ func GrantUserRoles(ctx context.Context, user string, database string, roles []s return nil } +func addRoles(ctx context.Context, db *sql.DB, database string, user string, roles map[string]struct{}) error { + if len(roles) == 0 { + // Nothing to do + return nil + } + + var rolesSlice []string + for elem := range roles { + rolesSlice = append(rolesSlice, elem) + } + + toAdd := strings.Join(rolesSlice, ",") + tsql := fmt.Sprintf("GRANT %s ON `%s`.* TO ?", toAdd, database) + _, err := db.ExecContext(ctx, tsql, user) + + return err +} + +func deleteRoles(ctx context.Context, db *sql.DB, database string, user string, roles map[string]struct{}) error { + if len(roles) == 0 { + // Nothing to do + return nil + } + + var rolesSlice []string + for elem := range roles { + rolesSlice = append(rolesSlice, elem) + } + + toDelete := strings.Join(rolesSlice, ",") + tsql := fmt.Sprintf("REVOKE %s ON `%s`.* FROM ?", toDelete, database) + _, err := db.ExecContext(ctx, tsql, user) + + return err +} + // UserExists checks if db contains user func UserExists(ctx context.Context, db *sql.DB, username string) (bool, error) {