-
Notifications
You must be signed in to change notification settings - Fork 204
/
mysqlhelper.go
369 lines (304 loc) · 10.5 KB
/
mysqlhelper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
*/
package mysql
import (
	"context"
	"database/sql"
	"fmt"
	"sort"
	"strings"

	"github.com/hashicorp/go-multierror"
	"github.com/pkg/errors"

	"github.com/Azure/azure-service-operator/pkg/errhelp"
	"github.com/Azure/azure-service-operator/pkg/helpers"
	"github.com/Azure/azure-service-operator/pkg/resourcemanager/config"
	"github.com/Azure/azure-service-operator/pkg/resourcemanager/iam"
)
const (
	// ServerPort is the TCP port used to connect to MySQL servers
	// (the standard MySQL port).
	ServerPort = 3306

	// DriverName is the database/sql driver name for MySQL connections,
	// intended to be passed as the driverName argument of the Connect*
	// helpers in this package.
	DriverName = "mysql"

	// SystemDatabase is the name of the system database in a MySQL server
	// where users and privileges are stored (and which we can always
	// assume will exist).
	SystemDatabase = "mysql"
)
// GetFullSQLServerName returns the fully qualified host name of the named
// MySQL server: serverName followed by "." and the MySQLDatabaseDNSSuffix of
// the environment returned by config.Environment().
func GetFullSQLServerName(serverName string) string {
	suffix := config.Environment().MySQLDatabaseDNSSuffix
	return strings.Join([]string{serverName, suffix}, ".")
}
// GetFullyQualifiedUserName joins userName and serverName into the
// "user@server" login form.
func GetFullyQualifiedUserName(userName string, serverName string) string {
	return userName + "@" + serverName
}
// ConnectToSqlDB opens a connection pool to database on fullServer:port
// using user/password authentication over TLS, and verifies it with a ping.
//
// On success the caller owns the returned *sql.DB and must Close it. On
// failure the returned *sql.DB is nil; a pool that was opened but could not
// be pinged is closed before returning so it does not leak.
func ConnectToSqlDB(ctx context.Context, driverName string, fullServer string, database string, port int, user string, password string) (*sql.DB, error) {
	// interpolateParams=true makes the driver substitute "?" placeholders
	// client-side (presumably needed for the GRANT/REVOKE/DROP USER
	// statements in this package that pass the user name as "?").
	connString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?tls=true&interpolateParams=true", user, password, fullServer, port, database)
	db, err := sql.Open(driverName, connString)
	if err != nil {
		return nil, errors.Wrapf(err, "opening the mysql db (%s:%d/%s)", fullServer, port, database)
	}

	if err := db.PingContext(ctx); err != nil {
		// The pool is unusable; release it. The ping error is the one worth
		// reporting, so a Close failure is deliberately ignored.
		_ = db.Close()
		return nil, errors.Wrapf(err, "error pinging the mysql db (%s:%d/%s)", fullServer, port, database)
	}
	return db, nil
}
// ConnectToSQLDBAsCurrentUser opens a connection pool to database on
// fullServer:port as user, authenticating with a token obtained from the
// MSI token provider for clientID (scoped to the OSSRDBMS resource), and
// verifies it with a ping.
//
// On success the caller owns the returned *sql.DB and must Close it. On
// failure the returned *sql.DB is nil; a pool that was opened but could not
// be pinged is closed before returning so it does not leak.
func ConnectToSQLDBAsCurrentUser(
	ctx context.Context,
	driverName string,
	fullServer string,
	database string,
	port int,
	user string,
	clientID string) (*sql.DB, error) {
	tokenProvider, err := iam.GetMSITokenProviderForResourceByClientID(
		config.Environment().ResourceIdentifiers.OSSRDBMS,
		clientID)
	if err != nil {
		// Returned unwrapped so callers can still classify the Azure error.
		return nil, err
	}

	// In our case we can't pass the provider directly so we just invoke it and get the token and use that
	token, err := tokenProvider()
	if err != nil {
		return nil, err
	}

	// See https://docs.microsoft.com/en-us/azure/mysql/howto-connect-with-managed-identity
	// As noted here https://docs.microsoft.com/en-us/azure/mysql/howto-configure-sign-in-azure-ad-authentication#compatibility-with-application-drivers
	// we must specify allowCleartextPasswords to pass a token
	connString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?tls=true&allowCleartextPasswords=true&interpolateParams=true", user, token, fullServer, port, database)
	db, err := sql.Open(driverName, connString)
	if err != nil {
		return nil, errors.Wrapf(err, "opening the mysql db (%s:%d/%s) as %s", fullServer, port, database, user)
	}

	if err := db.PingContext(ctx); err != nil {
		// The pool is unusable; release it. The ping error is the one worth
		// reporting, so a Close failure is deliberately ignored.
		_ = db.Close()
		return nil, errors.Wrapf(err, "error pinging the mysql db (%s:%d/%s) as %s", fullServer, port, database, user)
	}
	return db, nil
}
// formatUser renders user as a 'user'@'%' grantee string (the '%' host
// wildcard), the form compared against the GRANTEE column of the
// INFORMATION_SCHEMA privilege tables queried in this file.
func formatUser(user string) string {
	const anyHost = "%"
	return "'" + user + "'@'" + anyHost + "'"
}
// ExtractUserDatabaseRoles returns the database-level privileges held by
// user, keyed by database name. Each value is the set of PRIVILEGE_TYPE
// names that INFORMATION_SCHEMA.SCHEMA_PRIVILEGES lists for the user on
// that database. Databases on which the user holds nothing are absent.
func ExtractUserDatabaseRoles(ctx context.Context, db *sql.DB, user string) (map[string]StringSet, error) {
	// Looking only at SCHEMA_PRIVILEGES is enough because this package only
	// grants at database (or server) scope, never per table or column. Finer
	// grants would require joining more tables or parsing SHOW GRANTS.
	const query = "SELECT TABLE_SCHEMA, PRIVILEGE_TYPE FROM INFORMATION_SCHEMA.SCHEMA_PRIVILEGES WHERE GRANTEE = ?"
	rows, err := db.QueryContext(ctx, query, formatUser(user))
	if err != nil {
		return nil, errors.Wrapf(err, "listing database grants for user %s", user)
	}
	defer rows.Close()

	byDatabase := make(map[string]StringSet)
	for rows.Next() {
		var schema, privilege string
		if scanErr := rows.Scan(&schema, &privilege); scanErr != nil {
			return nil, errors.Wrapf(scanErr, "extracting privilege row")
		}
		set, seen := byDatabase[schema]
		if !seen {
			set = make(StringSet)
			byDatabase[schema] = set
		}
		set.Add(privilege)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, errors.Wrapf(iterErr, "iterating database privileges")
	}
	return byDatabase, nil
}
// ExtractUserServerRoles returns the set of server-level (global) privileges
// that INFORMATION_SCHEMA.USER_PRIVILEGES lists for user. The USAGE
// pseudo-privilege is filtered out because it is special and is never
// granted or revoked by this package.
func ExtractUserServerRoles(ctx context.Context, db *sql.DB, user string) (StringSet, error) {
	// Server-level grants made by this package always target *.*, so
	// USER_PRIVILEGES alone describes them; no SHOW GRANTS parsing is needed.
	const query = "SELECT PRIVILEGE_TYPE FROM INFORMATION_SCHEMA.USER_PRIVILEGES WHERE GRANTEE = ? AND PRIVILEGE_TYPE != 'USAGE'"
	rows, err := db.QueryContext(ctx, query, formatUser(user))
	if err != nil {
		return nil, errors.Wrapf(err, "listing grants for user %s", user)
	}
	defer rows.Close()

	privileges := make(StringSet)
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, errors.Wrapf(scanErr, "extracting privilege field")
		}
		privileges.Add(name)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, errors.Wrapf(iterErr, "iterating privileges")
	}
	return privileges, nil
}
// StringSet is a set of strings; a value is a member when it is present as
// a key. The zero-width struct{} values take no per-entry storage.
type StringSet map[string]struct{}

// SliceToSet returns a new StringSet containing every element of values.
// Duplicates collapse to a single entry. A nil or empty slice yields an
// empty, non-nil set that is safe to Add to.
func SliceToSet(values []string) StringSet {
	// Pre-size for the common case of distinct values to avoid rehashing.
	result := make(StringSet, len(values))
	for _, value := range values {
		result.Add(value)
	}
	return result
}

// Add inserts value into the set. Adding a value that is already present
// is a no-op.
func (s StringSet) Add(value string) {
	s[value] = struct{}{}
}
// EnsureUserServerRoles revokes and grants server-level roles as needed so
// the roles for user match roles. The roles to add and remove are computed
// by helpers.DiffCurrentAndExpectedSQLRoles from the user's current
// server-level privileges.
//
// The grant is attempted even if the revoke would fail and vice versa; all
// failures are reported together as one newline-separated error.
func EnsureUserServerRoles(ctx context.Context, db *sql.DB, user string, roles []string) error {
	// Validate the user name before it is used in any GRANT/REVOKE statement.
	if err := helpers.FindBadChars(user); err != nil {
		return errors.Wrap(err, "problem found with username")
	}

	desiredRoles := SliceToSet(roles)
	currentRoles, err := ExtractUserServerRoles(ctx, db, user)
	if err != nil {
		return errors.Wrapf(err, "couldn't get existing roles for user %s", user)
	}

	rolesDiff := helpers.DiffCurrentAndExpectedSQLRoles(currentRoles, desiredRoles)

	var errorStrings []string
	if err := addRoles(ctx, db, "", user, rolesDiff.AddedRoles); err != nil {
		errorStrings = append(errorStrings, err.Error())
	}
	if err := deleteRoles(ctx, db, "", user, rolesDiff.DeletedRoles); err != nil {
		errorStrings = append(errorStrings, err.Error())
	}
	if len(errorStrings) != 0 {
		// errors.New rather than fmt.Errorf: the joined messages are data,
		// and a '%' inside a driver error must not be read as a verb.
		return errors.New(strings.Join(errorStrings, "\n"))
	}
	return nil
}
// EnsureUserDatabaseRoles revokes and grants database-level roles as needed
// so they match dbRoles (database name -> role names). Every database that
// appears in dbRoles or on which user currently holds privileges is
// reconciled. A database present only in the current grants is diffed
// against an empty desired set.
//
// If applying privileges fails for one database, the remaining databases
// are still processed, and all failures are returned together as a
// multierror. Databases are processed in sorted order so the statement
// sequence and error order are deterministic.
//
// An empty database name is rejected rather than applied, because
// asGrantTarget maps "" to *.*, which would silently turn the request into
// a server-level grant or revoke.
func EnsureUserDatabaseRoles(ctx context.Context, conn *sql.DB, user string, dbRoles map[string][]string) error {
	// Same validation and wrapping as EnsureUserServerRoles; Wrap keeps the cause.
	if err := helpers.FindBadChars(user); err != nil {
		return errors.Wrap(err, "problem found with username")
	}

	desiredRoles := make(map[string]StringSet, len(dbRoles))
	for database, roles := range dbRoles {
		desiredRoles[database] = SliceToSet(roles)
	}

	currentRoles, err := ExtractUserDatabaseRoles(ctx, conn, user)
	if err != nil {
		return errors.Wrapf(err, "couldn't get existing database roles for user %s", user)
	}

	// Union of databases we want privileges on and databases we already have privileges on.
	allDatabases := make(StringSet, len(desiredRoles)+len(currentRoles))
	for db := range desiredRoles {
		allDatabases.Add(db)
	}
	for db := range currentRoles {
		allDatabases.Add(db)
	}
	databases := make([]string, 0, len(allDatabases))
	for db := range allDatabases {
		databases = append(databases, db)
	}
	sort.Strings(databases)

	var dbErrors error
	for _, db := range databases {
		if db == "" {
			dbErrors = multierror.Append(dbErrors, errors.New("empty database name is not allowed"))
			continue
		}
		rolesDiff := helpers.DiffCurrentAndExpectedSQLRoles(
			currentRoles[db],
			desiredRoles[db],
		)
		if err := addRoles(ctx, conn, db, user, rolesDiff.AddedRoles); err != nil {
			dbErrors = multierror.Append(dbErrors, errors.Wrap(err, db))
		}
		if err := deleteRoles(ctx, conn, db, user, rolesDiff.DeletedRoles); err != nil {
			dbErrors = multierror.Append(dbErrors, errors.Wrap(err, db))
		}
	}
	return dbErrors
}
func addRoles(ctx context.Context, db *sql.DB, database string, user string, roles StringSet) error {
if len(roles) == 0 {
// Nothing to do
return nil
}
var rolesSlice []string
for elem := range roles {
rolesSlice = append(rolesSlice, elem)
}
toAdd := strings.Join(rolesSlice, ",")
tsql := fmt.Sprintf("GRANT %s ON %s TO ?", toAdd, asGrantTarget(database))
_, err := db.ExecContext(ctx, tsql, user)
return err
}
func deleteRoles(ctx context.Context, db *sql.DB, database string, user string, roles StringSet) error {
if len(roles) == 0 {
// Nothing to do
return nil
}
var rolesSlice []string
for elem := range roles {
rolesSlice = append(rolesSlice, elem)
}
toDelete := strings.Join(rolesSlice, ",")
tsql := fmt.Sprintf("REVOKE %s ON %s FROM ?", toDelete, asGrantTarget(database))
_, err := db.ExecContext(ctx, tsql, user)
return err
}
// asGrantTarget formats database as the target of a GRANT or REVOKE
// statement. It returns "`database`.*" for database-level privileges, or
// "*.*" for server-level privileges when database is empty.
//
// Any backtick inside the name is doubled, which is how MySQL escapes it
// within a backtick-quoted identifier. This stops a database name from
// closing the quoting and injecting SQL.
func asGrantTarget(database string) string {
	if database == "" {
		return "*.*"
	}
	return "`" + strings.ReplaceAll(database, "`", "``") + "`.*"
}
// UserExists checks if db contains user
func UserExists(ctx context.Context, db *sql.DB, username string) (bool, error) {
err := db.QueryRowContext(ctx, "SELECT * FROM mysql.user WHERE User = $1", username)
if err != nil {
return false, nil
}
return true, nil
}
// DropUser drops a user from db
func DropUser(ctx context.Context, db *sql.DB, user string) error {
if err := helpers.FindBadChars(user); err != nil {
return errors.Wrap(err, "problem found with username")
}
_, err := db.ExecContext(ctx, "DROP USER IF EXISTS ?", user)
return err
}
// IsErrorResourceNotFound reports whether err, classified by
// errhelp.NewAzureError, has one of the types errhelp.ResourceNotFound,
// errhelp.ParentNotFoundErrorCode or errhelp.ResourceGroupNotFoundErrorCode.
// A nil error is never "not found".
//
// TODO: This is probably more generic than MySQL
func IsErrorResourceNotFound(err error) bool {
	if err == nil {
		// Nothing to classify. This also guards the azerr.Type dereference
		// below in case errhelp.NewAzureError returns nil for a nil error
		// (NOTE(review): its behavior is not visible here).
		return false
	}
	notFoundTypes := []string{
		errhelp.ResourceNotFound,
		errhelp.ParentNotFoundErrorCode,
		errhelp.ResourceGroupNotFoundErrorCode,
	}
	azerr := errhelp.NewAzureError(err)
	return helpers.ContainsString(notFoundTypes, azerr.Type)
}

// IgnoreResourceNotFound returns nil when IsErrorResourceNotFound(err) is
// true, and err unchanged otherwise (including nil).
func IgnoreResourceNotFound(err error) error {
	if IsErrorResourceNotFound(err) {
		return nil
	}
	return err
}
// IsErrorDatabaseBusy reports whether err's message contains
// "Please retry the connection later". A nil error is never busy; the
// original called err.Error() unconditionally and panicked on nil.
func IsErrorDatabaseBusy(err error) bool {
	return err != nil && strings.Contains(err.Error(), "Please retry the connection later")
}

// IgnoreDatabaseBusy returns nil when IsErrorDatabaseBusy(err) is true, and
// err unchanged otherwise. A nil err is returned as nil.
func IgnoreDatabaseBusy(err error) error {
	if IsErrorDatabaseBusy(err) {
		return nil
	}
	return err
}