From 3c5eda3ba99fd82eca160bdd2ffc764f8c9b9847 Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Tue, 7 Apr 2020 16:11:24 -0600 Subject: [PATCH 1/9] replace Eventually() in test helper to improve output --- controllers/helpers.go | 94 ++++++++++++++++++++++++++++++++------- pkg/helpers/retry.go | 30 +++++++++++++ pkg/helpers/retry_test.go | 31 +++++++++++++ 3 files changed, 139 insertions(+), 16 deletions(-) create mode 100644 pkg/helpers/retry.go create mode 100644 pkg/helpers/retry_test.go diff --git a/controllers/helpers.go b/controllers/helpers.go index e6b12bdcdb3..230b82a2c5b 100644 --- a/controllers/helpers.go +++ b/controllers/helpers.go @@ -224,16 +224,47 @@ func EnsureInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext, names := types.NamespacedName{Name: res.GetName(), Namespace: res.GetNamespace()} // Wait for first sql server to resolve - assert.Eventually(func() bool { - _ = tc.k8sClient.Get(ctx, names, instance) - return HasFinalizer(res, finalizerName) - }, tc.timeoutFast, tc.retry, "error waiting for %s to have finalizer", typeOf) + err = helpers.Retry(tc.timeoutFast, tc.retry, func() error { + err := tc.k8sClient.Get(ctx, names, instance) + if err != nil { + return err + } + + if !HasFinalizer(res, finalizerName) { + return fmt.Errorf("resource '%s' (%s) does not have finalizer '%s'", names.Name, typeOf, finalizerName) + } + return nil + }) + assert.Nil(err, "error waiting for %s to have finalizer", typeOf) + + err = helpers.Retry(tc.timeout, tc.retry, func() error { + err := tc.k8sClient.Get(ctx, names, instance) + if err != nil { + return err + } - assert.Eventually(func() bool { - _ = tc.k8sClient.Get(ctx, names, instance) statused := ConvertToStatus(instance) - return strings.Contains(statused.Status.Message, message) && statused.Status.Provisioned == provisioned - }, tc.timeout, tc.retry, "wait for %s to provision", typeOf) + if statused.Status.FailedProvisioning { + return helpers.NewStop(fmt.Errorf("Failed provisioning: %s", statused.Status.Message)) + } + if !strings.Contains(statused.Status.Message, message) || statused.Status.Provisioned != provisioned { + return fmt.Errorf( + `Expected: + Status.Message to contain %s + Status.Provisioned to be %t + Actual: + Message: '%s' + Provisioned: %t + `, + message, + provisioned, + statused.Status.Message, + statused.Status.Provisioned, + ) + } + return nil + }) + assert.Nil(err, "wait for %s to provision", typeOf) } @@ -274,16 +305,47 @@ func RequireInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext names := types.NamespacedName{Name: res.GetName(), Namespace: res.GetNamespace()} // Wait for first sql server to resolve - require.Eventually(func() bool { - _ = tc.k8sClient.Get(ctx, names, instance) - return HasFinalizer(res, finalizerName) - }, tc.timeoutFast, tc.retry, "error waiting for %s to have finalizer", typeOf) + err = helpers.Retry(tc.timeoutFast, tc.retry, func() error { + err := tc.k8sClient.Get(ctx, names, instance) + if err != nil { + return err + } + + if !HasFinalizer(res, finalizerName) { + return fmt.Errorf("resource '%s' (%s) does not have finalizer '%s'", names.Name, typeOf, finalizerName) + } + return nil + }) + require.Nil(err, "error waiting for %s to have finalizer", typeOf) + + err = helpers.Retry(tc.timeout, tc.retry, func() error { + err := tc.k8sClient.Get(ctx, names, instance) + if err != nil { + return err + } - require.Eventually(func() bool { - _ = tc.k8sClient.Get(ctx, names, instance) statused := ConvertToStatus(instance) - return strings.Contains(statused.Status.Message, message) && statused.Status.Provisioned == provisioned - }, tc.timeout, tc.retry, "wait for %s to provision", typeOf) + if statused.Status.FailedProvisioning { + return helpers.NewStop(fmt.Errorf("Failed provisioning: %s", statused.Status.Message)) + } + if !strings.Contains(statused.Status.Message, message) || statused.Status.Provisioned != provisioned { + return fmt.Errorf( + `Expected: + Status.Message to contain %s + Status.Provisioned to be %t + Actual: + Message: '%s' + Provisioned: %t + `, + message, + provisioned, + statused.Status.Message, + statused.Status.Provisioned, + ) + } + return nil + }) + require.Nil(err, "wait for %s to provision", typeOf) } diff --git a/pkg/helpers/retry.go b/pkg/helpers/retry.go new file mode 100644 index 00000000000..c1bb6b88a39 --- /dev/null +++ b/pkg/helpers/retry.go @@ -0,0 +1,30 @@ +package helpers + +import "time" + +func NewStop(e error) *StopErr { + return &StopErr{e} +} + +type StopErr struct { + Err error +} + +func (s *StopErr) Error() string { + return s.Err.Error() +} + +func Retry(timeout time.Duration, sleep time.Duration, fn func() error) error { + if err := fn(); err != nil { + // allow early exit + if v, ok := err.(*StopErr); ok { + return v.Err + } + if timeout > 0 { + time.Sleep(sleep) + return Retry(timeout-sleep, sleep, fn) + } + return err + } + return nil +} diff --git a/pkg/helpers/retry_test.go b/pkg/helpers/retry_test.go new file mode 100644 index 00000000000..93d99007057 --- /dev/null +++ b/pkg/helpers/retry_test.go @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package helpers + +import ( + "fmt" + "testing" + "time" +) + +func TestRetryTimeout(t *testing.T) { + start := time.Now() + _ = Retry(5*time.Second, 1*time.Second, func() error { + return fmt.Errorf("test") + }) + stop := time.Now().Sub(start) + if stop < 5*time.Second { + t.Errorf("retry ended too soon: %v", stop) + } +} +func TestRetryStopErr(t *testing.T) { + start := time.Now() + _ = Retry(5*time.Second, 1*time.Second, func() error { + return NewStop(fmt.Errorf("test")) + }) + stop := time.Now().Sub(start) + if stop > 1*time.Second { + t.Errorf("retry with stop should not take so long: %v", stop) + } +} From e653d8173a81eadb13d3c9f7d80554a66ca3b4b0 Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Tue, 7 Apr 2020 16:23:58 -0600 Subject: [PATCH 2/9] adding license header --- pkg/helpers/retry.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/helpers/retry.go b/pkg/helpers/retry.go index c1bb6b88a39..767525d6a96 100644 --- a/pkg/helpers/retry.go +++ b/pkg/helpers/retry.go @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + package helpers import "time" From 2610c577b06989a5603fa4426069017ed7e98120 Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Tue, 7 Apr 2020 17:16:50 -0600 Subject: [PATCH 3/9] change the way keyvault names are generated in tests --- controllers/keyvault_controller_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/controllers/keyvault_controller_test.go b/controllers/keyvault_controller_test.go index f3a78b2a03b..565b7486156 100644 --- a/controllers/keyvault_controller_test.go +++ b/controllers/keyvault_controller_test.go @@ -30,7 +30,7 @@ func TestKeyvaultControllerHappyPath(t *testing.T) { ctx := context.Background() assert := assert.New(t) - keyVaultName := helpers.FillWithRandom(GenerateTestResourceName("kv"), 24) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation @@ -72,7 +72,7 @@ func TestKeyvaultControllerWithAccessPolicies(t *testing.T) { ctx := context.Background() assert := assert.New(t) - keyVaultName := helpers.FillWithRandom(GenerateTestResourceName("kv"), 24) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation accessPolicies := []azurev1alpha1.AccessPolicyEntry{ @@ -150,7 +150,7 @@ func TestKeyvaultControllerWithLimitedAccessPoliciesAndUpdate(t *testing.T) { defer PanicRecover(t) ctx := context.Background() assert := assert.New(t) - keyVaultName := helpers.FillWithRandom(GenerateTestResourceName("kv"), 24) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation limitedPermissions := []string{"backup"} @@ -346,7 +346,7 @@ func TestKeyvaultControllerWithVirtualNetworkRulesAndUpdate(t *testing.T) { ctx := context.Background() assert := assert.New(t) - keyVaultName := helpers.FillWithRandom(GenerateTestResourceName("kv"), 24) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation accessPolicies := []azurev1alpha1.AccessPolicyEntry{ From 9e006554c35640e85e0117417078939c172581d0 Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Tue, 7 Apr 2020 17:23:13 -0600 Subject: [PATCH 4/9] don't exit on failed provision if expected result was failure --- controllers/helpers.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/controllers/helpers.go b/controllers/helpers.go index 230b82a2c5b..653696782aa 100644 --- a/controllers/helpers.go +++ b/controllers/helpers.go @@ -244,7 +244,8 @@ func EnsureInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext, } statused := ConvertToStatus(instance) - if statused.Status.FailedProvisioning { + // if we expect this resource to end up with provisioned == true then failedProvisioning == true is unrecoverable + if provisioned == true && statused.Status.FailedProvisioning { return helpers.NewStop(fmt.Errorf("Failed provisioning: %s", statused.Status.Message)) } if !strings.Contains(statused.Status.Message, message) || statused.Status.Provisioned != provisioned { From d5ef4cdb3201261cdba403ae2fd5b1bc33c94b8f Mon Sep 17 00:00:00 2001 From: Claudia Nadolny Date: Wed, 8 Apr 2020 09:15:47 -0700 Subject: [PATCH 5/9] filter out bad labels (#900) --- pkg/resourcemanager/rediscaches/rediscaches.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pkg/resourcemanager/rediscaches/rediscaches.go b/pkg/resourcemanager/rediscaches/rediscaches.go index f9bf0e4001d..e7fd22bb484 100644 --- a/pkg/resourcemanager/rediscaches/rediscaches.go +++ b/pkg/resourcemanager/rediscaches/rediscaches.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/redis/mgmt/2018-03-01/redis" model "github.com/Azure/azure-sdk-for-go/services/redis/mgmt/2018-03-01/redis" azurev1alpha1 "github.com/Azure/azure-service-operator/api/v1alpha1" + "github.com/Azure/azure-service-operator/pkg/helpers" "github.com/Azure/azure-service-operator/pkg/resourcemanager/config" "github.com/Azure/azure-service-operator/pkg/resourcemanager/iam" "github.com/Azure/azure-service-operator/pkg/secrets" @@ -54,11 +55,7 @@ func (r *AzureRedisCacheManager) CreateRedisCache( props := instance.Spec.Properties // convert kube labels to expected tag format - tags := map[string]*string{} - for k, v := range instance.GetLabels() { - value := v - tags[k] = &value - } + tags := helpers.LabelsToTags(instance.GetLabels()) redisClient, err := getRedisCacheClient() if err != nil { From 40a5e1e569001bce051d9f42e7ba34f2640057ba Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Wed, 8 Apr 2020 13:44:16 -0600 Subject: [PATCH 6/9] shorten keyvault names --- controllers/keyvault_controller_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/controllers/keyvault_controller_test.go b/controllers/keyvault_controller_test.go index 565b7486156..1620b901c88 100644 --- a/controllers/keyvault_controller_test.go +++ b/controllers/keyvault_controller_test.go @@ -30,7 +30,7 @@ func TestKeyvaultControllerHappyPath(t *testing.T) { ctx := context.Background() assert := assert.New(t) - keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 6) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation @@ -72,7 +72,7 @@ func TestKeyvaultControllerWithAccessPolicies(t *testing.T) { ctx := context.Background() assert := assert.New(t) - keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 6) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation accessPolicies := []azurev1alpha1.AccessPolicyEntry{ @@ -150,7 +150,7 @@ func TestKeyvaultControllerWithLimitedAccessPoliciesAndUpdate(t *testing.T) { defer PanicRecover(t) ctx := context.Background() assert := assert.New(t) - keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 6) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation limitedPermissions := []string{"backup"} @@ -346,7 +346,7 @@ func TestKeyvaultControllerWithVirtualNetworkRulesAndUpdate(t *testing.T) { ctx := context.Background() assert := assert.New(t) - keyVaultName := GenerateTestResourceNameWithRandom("kv", 10) + keyVaultName := GenerateTestResourceNameWithRandom("kv", 6) const poll = time.Second * 10 keyVaultLocation := tc.resourceGroupLocation accessPolicies := []azurev1alpha1.AccessPolicyEntry{ From 20b5dbb81e4e6eaa93adc54b4d0717ca8a0db3c9 Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Wed, 8 Apr 2020 17:01:28 -0600 Subject: [PATCH 7/9] require first server to move on --- controllers/azuresql_combined_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controllers/azuresql_combined_test.go b/controllers/azuresql_combined_test.go index 4e78d80cc0b..434be02bb8a 100644 --- a/controllers/azuresql_combined_test.go +++ b/controllers/azuresql_combined_test.go @@ -44,7 +44,7 @@ func TestAzureSqlServerCombinedHappyPath(t *testing.T) { sqlServerInstance2 := azurev1alpha1.NewAzureSQLServer(sqlServerNamespacedName2, rgName, rgLocation2) // create and wait - EnsureInstance(ctx, t, tc, sqlServerInstance) + RequireInstance(ctx, t, tc, sqlServerInstance) //verify secret exists in k8s for server 1 --------------------------------- secret := &v1.Secret{} From f9cadcc8a2a9b2b974388fc62da36815e60c0a71 Mon Sep 17 00:00:00 2001 From: Erin Corson Date: Wed, 8 Apr 2020 17:01:41 -0600 Subject: [PATCH 8/9] addressing pr comments --- controllers/helpers.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/controllers/helpers.go b/controllers/helpers.go index 653696782aa..29970bfe123 100644 --- a/controllers/helpers.go +++ b/controllers/helpers.go @@ -223,7 +223,7 @@ func EnsureInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext, names := types.NamespacedName{Name: res.GetName(), Namespace: res.GetNamespace()} - // Wait for first sql server to resolve + // Wait for finalizer err = helpers.Retry(tc.timeoutFast, tc.retry, func() error { err := tc.k8sClient.Get(ctx, names, instance) if err != nil { @@ -237,6 +237,7 @@ func EnsureInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext, }) assert.Nil(err, "error waiting for %s to have finalizer", typeOf) + // wait for provisioned and message to be as expected err = helpers.Retry(tc.timeout, tc.retry, func() error { err := tc.k8sClient.Get(ctx, names, instance) if err != nil { @@ -245,7 +246,7 @@ func EnsureInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext, statused := ConvertToStatus(instance) // if we expect this resource to end up with provisioned == true then failedProvisioning == true is unrecoverable - if provisioned == true && statused.Status.FailedProvisioning { + if provisioned && statused.Status.FailedProvisioning { return helpers.NewStop(fmt.Errorf("Failed provisioning: %s", statused.Status.Message)) } if !strings.Contains(statused.Status.Message, message) || statused.Status.Provisioned != provisioned { @@ -305,7 +306,7 @@ func RequireInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext names := types.NamespacedName{Name: res.GetName(), Namespace: res.GetNamespace()} - // Wait for first sql server to resolve + // Wait for finalizer err = helpers.Retry(tc.timeoutFast, tc.retry, func() error { err := tc.k8sClient.Get(ctx, names, instance) if err != nil { @@ -319,6 +320,7 @@ func RequireInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext }) require.Nil(err, "error waiting for %s to have finalizer", typeOf) + // wait for provisioned state and message to be as expected err = helpers.Retry(tc.timeout, tc.retry, func() error { err := tc.k8sClient.Get(ctx, names, instance) if err != nil { @@ -326,7 +328,7 @@ func RequireInstanceWithResult(ctx context.Context, t *testing.T, tc TestContext } statused := ConvertToStatus(instance) - if statused.Status.FailedProvisioning { + if provisioned && statused.Status.FailedProvisioning { return helpers.NewStop(fmt.Errorf("Failed provisioning: %s", statused.Status.Message)) } if !strings.Contains(statused.Status.Message, message) || statused.Status.Provisioned != provisioned { From baeb529645d849b6709c238003dd7ba748784f35 Mon Sep 17 00:00:00 2001 From: William Mortl <32373900+WilliamMortlMicrosoft@users.noreply.github.com> Date: Wed, 8 Apr 2020 19:32:10 -0700 Subject: [PATCH 9/9] Remove instantiations of GO SDK clients in the Ensure function for azuresqluser, gracefully fails on lack of firewall rule (#898) * first * refactored based on Erins changes * fixed firewall issue * deletion fix * syncing up with Justin and Erin * erin feedback 1 Co-authored-by: William Mortl Co-authored-by: Erin Corson --- .../azuresql/azuresqldb/azuresqldb.go | 19 +- .../azuresqlfailovergroup.go | 5 +- .../azuresql/azuresqlshared/getgoclients.go | 9 +- .../azuresql/azuresqlshared/resourceclient.go | 27 - .../azuresql/azuresqluser/azuresqluser.go | 460 ++---------------- .../azuresqluser/azuresqluser_manager.go | 8 + .../azuresqluser/azuresqluser_reconcile.go | 419 ++++++++++++++++ 7 files changed, 498 insertions(+), 449 deletions(-) delete mode 100644 pkg/resourcemanager/azuresql/azuresqlshared/resourceclient.go create mode 100644 pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_reconcile.go diff --git a/pkg/resourcemanager/azuresql/azuresqldb/azuresqldb.go b/pkg/resourcemanager/azuresql/azuresqldb/azuresqldb.go index dde5a2213ea..2f59d0c6ee4 100644 --- a/pkg/resourcemanager/azuresql/azuresqldb/azuresqldb.go +++ b/pkg/resourcemanager/azuresql/azuresqldb/azuresqldb.go @@ -34,7 +34,10 @@ func (_ *AzureSqlDbManager) GetServer(ctx context.Context, resourceGroupName str // GetDB retrieves a database func (_ *AzureSqlDbManager) GetDB(ctx context.Context, resourceGroupName string, serverName string, databaseName string) (sql.Database, error) { - dbClient := azuresqlshared.GetGoDbClient() + dbClient, err := azuresqlshared.GetGoDbClient() + if err != nil { + return sql.Database{}, err + } return dbClient.Get( ctx, @@ -65,7 +68,11 @@ func (sdk *AzureSqlDbManager) DeleteDB(ctx context.Context, resourceGroupName st return result, nil } - dbClient := azuresqlshared.GetGoDbClient() + dbClient, err := azuresqlshared.GetGoDbClient() + if err != nil { + return result, err + } + result, err = dbClient.Delete( ctx, resourceGroupName, @@ -78,7 +85,13 @@ func (sdk *AzureSqlDbManager) DeleteDB(ctx context.Context, resourceGroupName st // CreateOrUpdateDB creates or updates a DB in Azure func (_ *AzureSqlDbManager) CreateOrUpdateDB(ctx context.Context, resourceGroupName string, location string, serverName string, tags map[string]*string, properties azuresqlshared.SQLDatabaseProperties) (*http.Response, error) { - dbClient := azuresqlshared.GetGoDbClient() + dbClient, err := azuresqlshared.GetGoDbClient() + if err != nil { + return &http.Response{ + StatusCode: 0, + }, err + } + dbProp := azuresqlshared.SQLDatabasePropertiesToDatabase(properties) future, err := dbClient.CreateOrUpdate( diff --git a/pkg/resourcemanager/azuresql/azuresqlfailovergroup/azuresqlfailovergroup.go b/pkg/resourcemanager/azuresql/azuresqlfailovergroup/azuresqlfailovergroup.go index a6971d79fc1..ba36e73e454 100644 --- a/pkg/resourcemanager/azuresql/azuresqlfailovergroup/azuresqlfailovergroup.go +++ b/pkg/resourcemanager/azuresql/azuresqlfailovergroup/azuresqlfailovergroup.go @@ -42,7 +42,10 @@ func (f *AzureSqlFailoverGroupManager) GetServer(ctx context.Context, resourceGr // GetDB retrieves a database func (f *AzureSqlFailoverGroupManager) GetDB(ctx context.Context, resourceGroupName string, serverName string, databaseName string) (sql.Database, error) { - dbClient := azuresqlshared.GetGoDbClient() + dbClient, err := azuresqlshared.GetGoDbClient() + if err != nil { + return sql.Database{}, err + } return dbClient.Get( ctx, diff --git a/pkg/resourcemanager/azuresql/azuresqlshared/getgoclients.go b/pkg/resourcemanager/azuresql/azuresqlshared/getgoclients.go index aee36aa0907..2c960669c7c 100644 --- a/pkg/resourcemanager/azuresql/azuresqlshared/getgoclients.go +++ b/pkg/resourcemanager/azuresql/azuresqlshared/getgoclients.go @@ -11,12 +11,15 @@ import ( ) // GetGoDbClient retrieves a DatabasesClient -func GetGoDbClient() sql.DatabasesClient { +func GetGoDbClient() (sql.DatabasesClient, error) { dbClient := sql.NewDatabasesClientWithBaseURI(config.BaseURI(), config.SubscriptionID()) - a, _ := iam.GetResourceManagementAuthorizer() + a, err := iam.GetResourceManagementAuthorizer() + if err != nil { + return sql.DatabasesClient{}, err + } dbClient.Authorizer = a dbClient.AddToUserAgent(config.UserAgent()) - return dbClient + return dbClient, nil } // GetGoServersClient retrieves a ServersClient diff --git a/pkg/resourcemanager/azuresql/azuresqlshared/resourceclient.go b/pkg/resourcemanager/azuresql/azuresqlshared/resourceclient.go deleted file mode 100644 index b44438f6ebb..00000000000 --- a/pkg/resourcemanager/azuresql/azuresqlshared/resourceclient.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -package azuresqlshared - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/services/preview/sql/mgmt/2015-05-01-preview/sql" - "github.com/Azure/go-autorest/autorest" -) - -// ResourceClient contains the helper functions for interacting with SQL servers / databases -type ResourceClient interface { - CreateOrUpdateSQLServer(ctx context.Context, resourceGroupName string, location string, serverName string, tags map[string]*string, properties SQLServerProperties, forceUpdate bool) (result sql.Server, err error) - DeleteSQLServer(ctx context.Context, resourceGroupName string, serverName string) (result autorest.Response, err error) - GetServer(ctx context.Context, resourceGroupName string, serverName string) (result sql.Server, err error) - DeleteSQLFirewallRule(ctx context.Context, resourceGroupName string, serverName string, ruleName string) (err error) - CreateOrUpdateFailoverGroup(ctx context.Context, resourceGroupName string, serverName string, failovergroupname string, properties SQLFailoverGroupProperties) (result sql.FailoverGroupsCreateOrUpdateFuture, err error) - DeleteFailoverGroup(ctx context.Context, resourceGroupName string, serverName string, failoverGroupName string) (result autorest.Response, err error) - DeleteDB(ctx context.Context, resourceGroupName string, serverName string, databaseName string) (result autorest.Response, err error) - GetSQLFirewallRule(ctx context.Context, resourceGroupName string, serverName string, ruleName string) (result sql.FirewallRule, err error) - GetDB(ctx context.Context, resourceGroupName string, serverName string, databaseName string) (sql.Database, error) - GetFailoverGroup(ctx context.Context, resourceGroupName string, serverName string, failovergroupname string) (sql.FailoverGroup, error) - CreateOrUpdateDB(ctx context.Context, resourceGroupName string, location string, serverName string, tags map[string]*string, properties SQLDatabaseProperties) (sql.DatabasesCreateOrUpdateFuture, error) - CreateOrUpdateSQLFirewallRule(ctx context.Context, resourceGroupName string, serverName string, ruleName string, startIP string, endIP string) (result bool, err error) -} diff --git a/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser.go b/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser.go index 4170f7d4092..007d93d14bc 100644 --- a/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser.go +++ b/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser.go @@ -10,15 +10,12 @@ import ( "reflect" "strings" + azuresql "github.com/Azure/azure-sdk-for-go/services/preview/sql/mgmt/2015-05-01-preview/sql" "github.com/Azure/azure-service-operator/pkg/helpers" - "github.com/Azure/azure-service-operator/pkg/resourcemanager/azuresql/azuresqldb" + azuresqlshared "github.com/Azure/azure-service-operator/pkg/resourcemanager/azuresql/azuresqlshared" "github.com/Azure/azure-service-operator/pkg/secrets" "github.com/Azure/azure-service-operator/api/v1alpha1" - "github.com/Azure/azure-service-operator/pkg/errhelp" - "github.com/Azure/azure-service-operator/pkg/resourcemanager" - keyvaultSecrets "github.com/Azure/azure-service-operator/pkg/secrets/keyvault" - "github.com/google/uuid" "k8s.io/apimachinery/pkg/runtime" _ "github.com/denisenkom/go-mssqldb" @@ -49,8 +46,24 @@ func NewAzureSqlUserManager(secretClient secrets.SecretClient, scheme *runtime.S } } +// GetDB retrieves a database +func (s *AzureSqlUserManager) GetDB(ctx context.Context, resourceGroupName string, serverName string, databaseName string) (azuresql.Database, error) { + dbClient, err := azuresqlshared.GetGoDbClient() + if err != nil { + return azuresql.Database{}, err + } + + return dbClient.Get( + ctx, + resourceGroupName, + serverName, + databaseName, + "serviceTierAdvisors, transparentDataEncryption", + ) +} + // ConnectToSqlDb connects to the SQL db using the given credentials -func (m *AzureSqlUserManager) ConnectToSqlDb(ctx context.Context, drivername string, server string, database string, port int, user string, password string) (*sql.DB, error) { +func (s *AzureSqlUserManager) ConnectToSqlDb(ctx context.Context, drivername string, server string, database string, port int, user string, password string) (*sql.DB, error) { fullServerAddress := fmt.Sprintf("%s.database.windows.net", server) connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d;database=%s;Persist Security Info=False;Pooling=False;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30", fullServerAddress, user, password, port, database) @@ -69,7 +82,7 @@ func (m *AzureSqlUserManager) ConnectToSqlDb(ctx context.Context, drivername str } // GrantUserRoles grants roles to a user for a given database -func (m *AzureSqlUserManager) GrantUserRoles(ctx context.Context, user string, roles []string, db *sql.DB) error { +func (s *AzureSqlUserManager) GrantUserRoles(ctx context.Context, user string, roles []string, db *sql.DB) error { var errorStrings []string for _, role := range roles { tsql := "sp_addrolemember @role, @user" @@ -90,8 +103,8 @@ func (m *AzureSqlUserManager) GrantUserRoles(ctx context.Context, user string, r return nil } -// Creates user with secret credentials -func (m *AzureSqlUserManager) CreateUser(ctx context.Context, secret map[string][]byte, db *sql.DB) (string, error) { +// CreateUser creates user with secret credentials +func (s *AzureSqlUserManager) CreateUser(ctx context.Context, secret map[string][]byte, db *sql.DB) (string, error) { newUser := string(secret[SecretUsernameKey]) newPassword := string(secret[SecretPasswordKey]) @@ -117,7 +130,7 @@ func (m *AzureSqlUserManager) CreateUser(ctx context.Context, secret map[string] } // UpdateUser - Updates user password -func (m *AzureSqlUserManager) UpdateUser(ctx context.Context, secret map[string][]byte, db *sql.DB) error { +func (s *AzureSqlUserManager) UpdateUser(ctx context.Context, secret map[string][]byte, db *sql.DB) error { user := string(secret[SecretUsernameKey]) newPassword := helpers.NewPassword() @@ -136,7 +149,7 @@ func (m *AzureSqlUserManager) UpdateUser(ctx context.Context, secret map[string] } // UserExists checks if db contains user -func (m *AzureSqlUserManager) UserExists(ctx context.Context, db *sql.DB, username string) (bool, error) { +func (s *AzureSqlUserManager) UserExists(ctx context.Context, db *sql.DB, username string) (bool, error) { res, err := db.ExecContext( ctx, "SELECT * FROM sysusers WHERE NAME=@user", @@ -149,398 +162,14 @@ func (m *AzureSqlUserManager) UserExists(ctx context.Context, db *sql.DB, userna return rows > 0, err } -// Drops user from db -func (m *AzureSqlUserManager) DropUser(ctx context.Context, db *sql.DB, user string) error { +// DropUser drops a user from db +func (s *AzureSqlUserManager) DropUser(ctx context.Context, db *sql.DB, user string) error { tsql := "DROP USER @user" _, err := db.ExecContext(ctx, tsql, sql.Named("user", user)) return err } -func (s *AzureSqlUserManager) Ensure(ctx context.Context, obj runtime.Object, opts ...resourcemanager.ConfigOption) (bool, error) { - instance, err := s.convert(obj) - if err != nil { - return false, err - } - - requestedUsername := instance.Spec.Username - if len(requestedUsername) == 0 { - requestedUsername = instance.Name - } - - options := &resourcemanager.Options{} - for _, opt := range opts { - opt(options) - } - - adminSecretClient := s.SecretClient - - adminsecretName := instance.Spec.AdminSecret - if len(instance.Spec.AdminSecret) == 0 { - adminsecretName = instance.Spec.Server - } - - key := types.NamespacedName{Name: adminsecretName, Namespace: instance.Namespace} - - var sqlUserSecretClient secrets.SecretClient - if options.SecretClient != nil { - sqlUserSecretClient = options.SecretClient - } else { - sqlUserSecretClient = s.SecretClient - } - - // if the admin secret keyvault is not specified, fall back to global secretclient - if len(instance.Spec.AdminSecretKeyVault) != 0 { - adminSecretClient = keyvaultSecrets.New(instance.Spec.AdminSecretKeyVault) - if len(instance.Spec.AdminSecret) != 0 { - key = types.NamespacedName{Name: instance.Spec.AdminSecret} - } - } - - // need this to detect missing databases - dbClient := azuresqldb.NewAzureSqlDbManager() - - // get admin creds for server - adminSecret, err := adminSecretClient.Get(ctx, key) - if err != nil { - instance.Status.Provisioning = false - instance.Status.Message = fmt.Sprintf("admin secret : %s, not found in %s", key.String(), reflect.TypeOf(adminSecretClient).Elem().Name()) - return false, nil - } - - adminUser := string(adminSecret[SecretUsernameKey]) - adminPassword := string(adminSecret[SecretPasswordKey]) - - _, err = dbClient.GetDB(ctx, instance.Spec.ResourceGroup, instance.Spec.Server, instance.Spec.DbName) - if err != nil { - instance.Status.Message = err.Error() - - catch := []string{ - errhelp.ResourceNotFound, - errhelp.ParentNotFoundErrorCode, - errhelp.ResourceGroupNotFoundErrorCode, - } - azerr := errhelp.NewAzureErrorAzureError(err) - if helpers.ContainsString(catch, azerr.Type) { - return false, nil - } - return false, err - } - - db, err := s.ConnectToSqlDb(ctx, DriverName, instance.Spec.Server, instance.Spec.DbName, SqlServerPort, adminUser, adminPassword) - if err != nil { - instance.Status.Message = err.Error() - if strings.Contains(err.Error(), "create a firewall rule for this IP address") { - return false, nil - } - return false, err - } - - // determine our key namespace - if we're persisting to kube, we should use the actual instance namespace. - // In keyvault we have to avoid collisions with other secrets so we create a custom namespace with the user's parameters - key = GetNamespacedName(instance, sqlUserSecretClient) - - // create or get new user secret - DBSecret := s.GetOrPrepareSecret(ctx, instance, sqlUserSecretClient) - // reset user from secret in case it was loaded - user := string(DBSecret[SecretUsernameKey]) - if user == "" { - user = fmt.Sprintf("%s-%s", requestedUsername, uuid.New()) - DBSecret[SecretUsernameKey] = []byte(user) - } - - // Publishing the user secret: - // We do this first so if the keyvault does not have right permissions we will not proceed to creating the user - err = sqlUserSecretClient.Upsert( - ctx, - key, - DBSecret, - secrets.WithOwner(instance), - secrets.WithScheme(s.Scheme), - ) - if err != nil { - instance.Status.Message = "failed to update secret, err: " + err.Error() - return false, err - } - - // Preformatted special formats are only available through keyvault as they require separated secrets - keyVaultEnabled := reflect.TypeOf(sqlUserSecretClient).Elem().Name() == "KeyvaultSecretClient" - if keyVaultEnabled { - // Instantiate a map of all formats and flip the bool to true for any that have been requested in the spec. - // Formats that were not requested will be explicitly deleted. - requestedFormats := map[string]bool{ - "adonet": false, - "adonet-urlonly": false, - "jdbc": false, - "jdbc-urlonly": false, - "odbc": false, - "odbc-urlonly": false, - "server": false, - "database": false, - "username": false, - "password": false, - } - for _, format := range instance.Spec.KeyVaultSecretFormats { - requestedFormats[format] = true - } - - // Deleted items will be processed immediately but secrets that need to be added will be created in this array and persisted in one pass at the end - formattedSecrets := make(map[string][]byte) - - for formatName, requested := range requestedFormats { - // Add the format to the output map if it has been requested otherwise call for its deletion from the secret store - if requested { - switch formatName { - case "adonet": - formattedSecrets["adonet"] = []byte(fmt.Sprintf( - "Server=tcp:%v,1433;Initial Catalog=%v;Persist Security Info=False;User ID=%v;Password=%v;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;", - string(DBSecret["fullyQualifiedServerName"]), - instance.Spec.DbName, - user, - string(DBSecret["password"]), - )) - - case "adonet-urlonly": - formattedSecrets["adonet-urlonly"] = []byte(fmt.Sprintf( - "Server=tcp:%v,1433;Initial Catalog=%v;Persist Security Info=False; MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout", - string(DBSecret["fullyQualifiedServerName"]), - instance.Spec.DbName, - )) - - case "jdbc": - formattedSecrets["jdbc"] = []byte(fmt.Sprintf( - "jdbc:sqlserver://%v:1433;database=%v;user=%v@%v;password=%v;encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;", - string(DBSecret["fullyQualifiedServerName"]), - instance.Spec.DbName, - user, - instance.Spec.Server, - string(DBSecret["password"]), - )) - case "jdbc-urlonly": - formattedSecrets["jdbc-urlonly"] = []byte(fmt.Sprintf( - "jdbc:sqlserver://%v:1433;database=%v;encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;", - string(DBSecret["fullyQualifiedServerName"]), - instance.Spec.DbName, - )) - - case "odbc": - formattedSecrets["odbc"] = []byte(fmt.Sprintf( - "Server=tcp:%v,1433;Initial Catalog=%v;Persist Security Info=False;User ID=%v;Password=%v;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;", - string(DBSecret["fullyQualifiedServerName"]), - instance.Spec.DbName, - user, - string(DBSecret["password"]), - )) - case "odbc-urlonly": - formattedSecrets["odbc-urlonly"] = []byte(fmt.Sprintf( - "Driver={ODBC Driver 13 for SQL Server};Server=tcp:%v,1433;Database=%v; Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;", - string(DBSecret["fullyQualifiedServerName"]), - instance.Spec.DbName, - )) - case "server": - formattedSecrets["server"] = DBSecret["fullyQualifiedServerName"] - - case "database": - formattedSecrets["database"] = []byte(instance.Spec.DbName) - - case "username": - formattedSecrets["username"] = []byte(user) - - case "password": - formattedSecrets["password"] = DBSecret["password"] - } - } else { - err = sqlUserSecretClient.Delete( - ctx, - types.NamespacedName{Namespace: key.Namespace, Name: instance.Name + "-" + formatName}, - ) - } - } - - err = sqlUserSecretClient.Upsert( - ctx, - types.NamespacedName{Namespace: key.Namespace, Name: instance.Name}, - formattedSecrets, - secrets.WithOwner(instance), - secrets.WithScheme(s.Scheme), - secrets.Flatten(true), - ) - if err != nil { - return false, err - } - } - - userExists, err := s.UserExists(ctx, db, string(DBSecret[SecretUsernameKey])) - if err != nil { - instance.Status.Message = fmt.Sprintf("failed checking for user, err: %v", err) - return false, nil - } - - if !userExists { - user, err = s.CreateUser(ctx, DBSecret, db) - if err != nil { - instance.Status.Message = "failed creating user, err: " + err.Error() - return false, err - } - } - - // apply roles to user - if len(instance.Spec.Roles) == 0 { - instance.Status.Message = "No roles specified for user" - return false, fmt.Errorf("No roles specified for database user") - } - - err = s.GrantUserRoles(ctx, user, instance.Spec.Roles, db) - if err != nil { - fmt.Println(err) - instance.Status.Message = "GrantUserRoles failed" - return false, fmt.Errorf("GrantUserRoles failed") - } - - instance.Status.Provisioned = true - instance.Status.State = "Succeeded" - instance.Status.Message = resourcemanager.SuccessMsg - - return true, nil -} - -func (s *AzureSqlUserManager) Delete(ctx context.Context, obj runtime.Object, opts ...resourcemanager.ConfigOption) (bool, error) { - - options := &resourcemanager.Options{} - for _, opt := range opts { - opt(options) - } - - instance, err := s.convert(obj) - if err != nil { - return false, err - } - - adminSecretClient := s.SecretClient - - adminsecretName := instance.Spec.AdminSecret - if len(instance.Spec.AdminSecret) == 0 { - adminsecretName = instance.Spec.Server - } - key := types.NamespacedName{Name: adminsecretName, Namespace: instance.Namespace} - - var sqlUserSecretClient secrets.SecretClient - if options.SecretClient != nil { - sqlUserSecretClient = options.SecretClient - } else { - sqlUserSecretClient = s.SecretClient - } - - // if the admin secret keyvault is not specified, fall back to global secretclient - if len(instance.Spec.AdminSecretKeyVault) != 0 { - adminSecretClient = keyvaultSecrets.New(instance.Spec.AdminSecretKeyVault) - if len(instance.Spec.AdminSecret) != 0 { - key = types.NamespacedName{Name: instance.Spec.AdminSecret} - } - } - - adminSecret, err := adminSecretClient.Get(ctx, key) - if err != nil { - // assuming if the admin secret is gone the sql server is too - return false, nil - } - - // short circuit connection if database doesn't exist - dbClient := azuresqldb.NewAzureSqlDbManager() - _, err = dbClient.GetDB(ctx, instance.Spec.ResourceGroup, instance.Spec.Server, instance.Spec.DbName) - if err != nil { - instance.Status.Message = err.Error() - - catch := []string{ - errhelp.ResourceNotFound, - errhelp.ParentNotFoundErrorCode, - errhelp.ResourceGroupNotFoundErrorCode, - } - azerr := errhelp.NewAzureErrorAzureError(err) - if helpers.ContainsString(catch, azerr.Type) { - return false, nil - } - return false, err - } - - var user = string(adminSecret[SecretUsernameKey]) - var password = string(adminSecret[SecretPasswordKey]) - - db, err := s.ConnectToSqlDb(ctx, DriverName, instance.Spec.Server, instance.Spec.DbName, SqlServerPort, user, password) - if err != nil { - return false, err - } - - exists, err := s.UserExists(ctx, db, user) - if err != nil { - return true, err - } - if !exists { - s.DeleteSecrets(ctx, instance, sqlUserSecretClient) - return false, nil - } - - err = s.DropUser(ctx, db, user) - if err != nil { - instance.Status.Message = fmt.Sprintf("Delete AzureSqlUser failed with %s", err.Error()) - return false, err - } - - // Once the user has been dropped, also delete their secrets. - s.DeleteSecrets(ctx, instance, sqlUserSecretClient) - - instance.Status.Message = fmt.Sprintf("Delete AzureSqlUser succeeded") - - return true, nil -} - -func (s *AzureSqlUserManager) GetParents(obj runtime.Object) ([]resourcemanager.KubeParent, error) { - instance, err := s.convert(obj) - if err != nil { - return nil, err - } - - return []resourcemanager.KubeParent{ - { - Key: types.NamespacedName{ - Namespace: instance.Namespace, - Name: instance.Spec.DbName, - }, - Target: &v1alpha1.AzureSqlDatabase{}, - }, - { - Key: types.NamespacedName{ - Namespace: instance.Namespace, - Name: instance.Spec.Server, - }, - Target: &v1alpha1.AzureSqlServer{}, - }, - { - Key: types.NamespacedName{ - Namespace: instance.Namespace, - Name: instance.Spec.ResourceGroup, - }, - Target: &v1alpha1.ResourceGroup{}, - }, - }, nil -} - -func (g *AzureSqlUserManager) GetStatus(obj runtime.Object) (*v1alpha1.ASOStatus, error) { - instance, err := g.convert(obj) - if err != nil { - return nil, err - } - return &instance.Status, nil -} - -func (s *AzureSqlUserManager) convert(obj runtime.Object) (*v1alpha1.AzureSQLUser, error) { - local, ok := obj.(*v1alpha1.AzureSQLUser) - if !ok { - return nil, fmt.Errorf("failed type assertion on kind: %s", obj.GetObjectKind().GroupVersionKind().String()) - } - return local, nil -} - -// Deletes the secrets associated with a SQLUser +// DeleteSecrets deletes the secrets associated with a SQLUser func (s *AzureSqlUserManager) DeleteSecrets(ctx context.Context, instance *v1alpha1.AzureSQLUser, secretClient secrets.SecretClient) (bool, error) { // determine our key namespace - if we're persisting to kube, we should use the actual instance namespace. // In keyvault we have some creative freedom to allow more flexibility @@ -609,23 +238,7 @@ func (s *AzureSqlUserManager) GetOrPrepareSecret(ctx context.Context, instance * return secret } -func findBadChars(stack string) error { - badChars := []string{ - "'", - "\"", - ";", - "--", - "/*", - } - - for _, s := range badChars { - if idx := strings.Index(stack, s); idx > -1 { - return fmt.Errorf("potentially dangerous character seqience found: '%s' at pos: %d", s, idx) - } - } - return nil -} - +// GetNamespacedName gets the namespaced-name func GetNamespacedName(instance *v1alpha1.AzureSQLUser, secretClient secrets.SecretClient) types.NamespacedName { var namespacedName types.NamespacedName keyVaultEnabled := reflect.TypeOf(secretClient).Elem().Name() == "KeyvaultSecretClient" @@ -646,3 +259,20 @@ func GetNamespacedName(instance *v1alpha1.AzureSQLUser, secretClient secrets.Sec return namespacedName } + +func findBadChars(stack string) error { + badChars := []string{ + "'", + "\"", + ";", + "--", + "/*", + } + + for _, s := range badChars { + if idx := strings.Index(stack, s); idx > -1 { + return fmt.Errorf("potentially dangerous character seqience found: '%s' at pos: %d", s, idx) + } + } + return nil +} diff --git a/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_manager.go b/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_manager.go index 53855077bcb..28a38482dfd 100644 --- a/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_manager.go +++ b/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_manager.go @@ -7,15 +7,23 @@ import ( "context" "database/sql" + azuresql "github.com/Azure/azure-sdk-for-go/services/preview/sql/mgmt/2015-05-01-preview/sql" + + "github.com/Azure/azure-service-operator/api/v1alpha1" "github.com/Azure/azure-service-operator/pkg/resourcemanager" + "github.com/Azure/azure-service-operator/pkg/secrets" ) type SqlUserManager interface { + GetDB(ctx context.Context, resourceGroupName string, serverName string, databaseName string) (azuresql.Database, error) ConnectToSqlDb(ctx context.Context, drivername string, server string, dbname string, port int, username string, password string) (*sql.DB, error) GrantUserRoles(ctx context.Context, user string, roles []string, db *sql.DB) error CreateUser(ctx context.Context, secret map[string][]byte, db *sql.DB) (string, error) UserExists(ctx context.Context, db *sql.DB, username string) (bool, error) DropUser(ctx context.Context, db *sql.DB, user string) error + DeleteSecrets(ctx context.Context, instance *v1alpha1.AzureSQLUser, secretClient secrets.SecretClient) (bool, error) + GetOrPrepareSecret(ctx context.Context, instance *v1alpha1.AzureSQLUser, secretClient secrets.SecretClient) map[string][]byte + // also embed methods from AsyncClient resourcemanager.ARMClient } diff --git a/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_reconcile.go b/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_reconcile.go new file mode 100644 index 00000000000..235838a30a0 --- /dev/null +++ b/pkg/resourcemanager/azuresql/azuresqluser/azuresqluser_reconcile.go @@ -0,0 +1,419 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +package azuresqluser + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/Azure/azure-service-operator/pkg/helpers" + "github.com/Azure/azure-service-operator/pkg/secrets" + + "github.com/Azure/azure-service-operator/api/v1alpha1" + "github.com/Azure/azure-service-operator/pkg/errhelp" + "github.com/Azure/azure-service-operator/pkg/resourcemanager" + keyvaultSecrets "github.com/Azure/azure-service-operator/pkg/secrets/keyvault" + "github.com/google/uuid" + "k8s.io/apimachinery/pkg/runtime" + + _ "github.com/denisenkom/go-mssqldb" + "k8s.io/apimachinery/pkg/types" +) + +// Ensure that user exists +func (s *AzureSqlUserManager) Ensure(ctx context.Context, obj runtime.Object, opts ...resourcemanager.ConfigOption) (bool, error) { + instance, err := s.convert(obj) + if err != nil { + return false, err + } + + requestedUsername := instance.Spec.Username + if len(requestedUsername) == 0 { + requestedUsername = instance.Name + } + + options := &resourcemanager.Options{} + for _, opt := range opts { + opt(options) + } + + adminSecretClient := s.SecretClient + + adminsecretName := instance.Spec.AdminSecret + if len(instance.Spec.AdminSecret) == 0 { + adminsecretName = instance.Spec.Server + } + + key := types.NamespacedName{Name: adminsecretName, Namespace: instance.Namespace} + + var sqlUserSecretClient secrets.SecretClient + if options.SecretClient != nil { + sqlUserSecretClient = options.SecretClient + } else { + sqlUserSecretClient = s.SecretClient + } + + // if the admin secret keyvault is not specified, fall back to global secretclient + if len(instance.Spec.AdminSecretKeyVault) != 0 { + adminSecretClient = keyvaultSecrets.New(instance.Spec.AdminSecretKeyVault) + if len(instance.Spec.AdminSecret) != 0 { + key = types.NamespacedName{Name: instance.Spec.AdminSecret} + } + } + + // get admin creds for server + adminSecret, err := adminSecretClient.Get(ctx, key) + if err != nil { + instance.Status.Provisioning = false + instance.Status.Message = fmt.Sprintf("admin secret : %s, not found in %s", key.String(), reflect.TypeOf(adminSecretClient).Elem().Name()) + return false, nil + } + + adminUser := string(adminSecret[SecretUsernameKey]) + adminPassword := string(adminSecret[SecretPasswordKey]) + + _, err = s.GetDB(ctx, instance.Spec.ResourceGroup, instance.Spec.Server, instance.Spec.DbName) + if err != nil { + instance.Status.Message = err.Error() + + catch := []string{ + errhelp.ResourceNotFound, + errhelp.ParentNotFoundErrorCode, + errhelp.ResourceGroupNotFoundErrorCode, + } + azerr := errhelp.NewAzureErrorAzureError(err) + if helpers.ContainsString(catch, azerr.Type) { + return false, nil + } + return false, err + } + + db, err := s.ConnectToSqlDb(ctx, DriverName, instance.Spec.Server, instance.Spec.DbName, SqlServerPort, adminUser, adminPassword) + if err != nil { + instance.Status.Message = errhelp.StripErrorIDs(err) + + // catch firewall issue - keep cycling until it clears up + if strings.Contains(err.Error(), "create a firewall rule for this IP address") { + instance.Status.Provisioned = false + instance.Status.Provisioning = false + return false, nil + } + + return false, err + } + + // determine our key namespace - if we're persisting to kube, we should use the actual instance namespace. + // In keyvault we have to avoid collisions with other secrets so we create a custom namespace with the user's parameters + key = GetNamespacedName(instance, sqlUserSecretClient) + + // create or get new user secret + DBSecret := s.GetOrPrepareSecret(ctx, instance, sqlUserSecretClient) + // reset user from secret in case it was loaded + user := string(DBSecret[SecretUsernameKey]) + if user == "" { + user = fmt.Sprintf("%s-%s", requestedUsername, uuid.New()) + DBSecret[SecretUsernameKey] = []byte(user) + } + + // Publishing the user secret: + // We do this first so if the keyvault does not have right permissions we will not proceed to creating the user + err = sqlUserSecretClient.Upsert( + ctx, + key, + DBSecret, + secrets.WithOwner(instance), + secrets.WithScheme(s.Scheme), + ) + if err != nil { + instance.Status.Message = "failed to update secret, err: " + err.Error() + return false, err + } + + // Preformatted special formats are only available through keyvault as they require separated secrets + keyVaultEnabled := reflect.TypeOf(sqlUserSecretClient).Elem().Name() == "KeyvaultSecretClient" + if keyVaultEnabled { + // Instantiate a map of all formats and flip the bool to true for any that have been requested in the spec. + // Formats that were not requested will be explicitly deleted. + requestedFormats := map[string]bool{ + "adonet": false, + "adonet-urlonly": false, + "jdbc": false, + "jdbc-urlonly": false, + "odbc": false, + "odbc-urlonly": false, + "server": false, + "database": false, + "username": false, + "password": false, + } + for _, format := range instance.Spec.KeyVaultSecretFormats { + requestedFormats[format] = true + } + + // Deleted items will be processed immediately but secrets that need to be added will be created in this array and persisted in one pass at the end + formattedSecrets := make(map[string][]byte) + + for formatName, requested := range requestedFormats { + // Add the format to the output map if it has been requested otherwise call for its deletion from the secret store + if requested { + switch formatName { + case "adonet": + formattedSecrets["adonet"] = []byte(fmt.Sprintf( + "Server=tcp:%v,1433;Initial Catalog=%v;Persist Security Info=False;User ID=%v;Password=%v;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;", + string(DBSecret["fullyQualifiedServerName"]), + instance.Spec.DbName, + user, + string(DBSecret["password"]), + )) + + case "adonet-urlonly": + formattedSecrets["adonet-urlonly"] = []byte(fmt.Sprintf( + "Server=tcp:%v,1433;Initial Catalog=%v;Persist Security Info=False; MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout", + string(DBSecret["fullyQualifiedServerName"]), + instance.Spec.DbName, + )) + + case "jdbc": + formattedSecrets["jdbc"] = []byte(fmt.Sprintf( + "jdbc:sqlserver://%v:1433;database=%v;user=%v@%v;password=%v;encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;", + string(DBSecret["fullyQualifiedServerName"]), + instance.Spec.DbName, + user, + instance.Spec.Server, + string(DBSecret["password"]), + )) + case "jdbc-urlonly": + formattedSecrets["jdbc-urlonly"] = []byte(fmt.Sprintf( + "jdbc:sqlserver://%v:1433;database=%v;encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;", + string(DBSecret["fullyQualifiedServerName"]), + instance.Spec.DbName, + )) + + case "odbc": + formattedSecrets["odbc"] = []byte(fmt.Sprintf( + "Server=tcp:%v,1433;Initial Catalog=%v;Persist Security Info=False;User ID=%v;Password=%v;MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Connection Timeout=30;", + string(DBSecret["fullyQualifiedServerName"]), + instance.Spec.DbName, + user, + string(DBSecret["password"]), + )) + case "odbc-urlonly": + formattedSecrets["odbc-urlonly"] = []byte(fmt.Sprintf( + "Driver={ODBC Driver 13 for SQL Server};Server=tcp:%v,1433;Database=%v; Encrypt=yes;TrustServerCertificate=no;Connection Timeout=30;", + string(DBSecret["fullyQualifiedServerName"]), + instance.Spec.DbName, + )) + case "server": + formattedSecrets["server"] = DBSecret["fullyQualifiedServerName"] + + case "database": + formattedSecrets["database"] = []byte(instance.Spec.DbName) + + case "username": + formattedSecrets["username"] = []byte(user) + + case "password": + formattedSecrets["password"] = DBSecret["password"] + } + } else { + err = sqlUserSecretClient.Delete( + ctx, + types.NamespacedName{Namespace: key.Namespace, Name: instance.Name + "-" + formatName}, + ) + } + } + + err = sqlUserSecretClient.Upsert( + ctx, + types.NamespacedName{Namespace: key.Namespace, Name: instance.Name}, + formattedSecrets, + secrets.WithOwner(instance), + secrets.WithScheme(s.Scheme), + secrets.Flatten(true), + ) + if err != nil { + return false, err + } + } + + userExists, err := s.UserExists(ctx, db, string(DBSecret[SecretUsernameKey])) + if err != nil { + instance.Status.Message = fmt.Sprintf("failed checking for user, err: %v", err) + return false, nil + } + + if !userExists { + user, err = s.CreateUser(ctx, DBSecret, db) + if err != nil { + instance.Status.Message = "failed creating user, err: " + err.Error() + return false, err + } + } + + // apply roles to user + if len(instance.Spec.Roles) == 0 { + instance.Status.Message = "No roles specified for user" + return false, fmt.Errorf("No roles specified for database user") + } + + err = s.GrantUserRoles(ctx, user, instance.Spec.Roles, db) + if err != nil { + fmt.Println(err) + instance.Status.Message = "GrantUserRoles failed" + return false, fmt.Errorf("GrantUserRoles failed") + } + + instance.Status.Provisioned = true + instance.Status.State = "Succeeded" + instance.Status.Message = resourcemanager.SuccessMsg + + return true, nil +} + +// Delete deletes a user +func (s *AzureSqlUserManager) Delete(ctx context.Context, obj runtime.Object, opts ...resourcemanager.ConfigOption) (bool, error) { + + options := &resourcemanager.Options{} + for _, opt := range opts { + opt(options) + } + + instance, err := s.convert(obj) + if err != nil { + return false, err + } + + adminSecretClient := s.SecretClient + + adminsecretName := instance.Spec.AdminSecret + if len(instance.Spec.AdminSecret) == 0 { + adminsecretName = instance.Spec.Server + } + key := types.NamespacedName{Name: adminsecretName, Namespace: instance.Namespace} + + var sqlUserSecretClient secrets.SecretClient + if options.SecretClient != nil { + sqlUserSecretClient = options.SecretClient + } else { + sqlUserSecretClient = s.SecretClient + } + + // if the admin secret keyvault is not specified, fall back to global secretclient + if len(instance.Spec.AdminSecretKeyVault) != 0 { + adminSecretClient = keyvaultSecrets.New(instance.Spec.AdminSecretKeyVault) + if len(instance.Spec.AdminSecret) != 0 { + key = types.NamespacedName{Name: instance.Spec.AdminSecret} + } + } + + adminSecret, err := adminSecretClient.Get(ctx, key) + if err != nil { + // assuming if the admin secret is gone the sql server is too + return false, nil + } + + // short circuit connection if database doesn't exist + _, err = s.GetDB(ctx, instance.Spec.ResourceGroup, instance.Spec.Server, instance.Spec.DbName) + if err != nil { + instance.Status.Message = err.Error() + + catch := []string{ + errhelp.ResourceNotFound, + errhelp.ParentNotFoundErrorCode, + errhelp.ResourceGroupNotFoundErrorCode, + } + azerr := errhelp.NewAzureErrorAzureError(err) + if helpers.ContainsString(catch, azerr.Type) { + return false, nil + } + return false, err + } + + var user = string(adminSecret[SecretUsernameKey]) + var password = string(adminSecret[SecretPasswordKey]) + + db, err := s.ConnectToSqlDb(ctx, DriverName, instance.Spec.Server, instance.Spec.DbName, SqlServerPort, user, password) + if err != nil { + instance.Status.Message = errhelp.StripErrorIDs(err) + if strings.Contains(err.Error(), "create a firewall rule for this IP address") { + + // there is nothing much we can do here - cycle forever + return true, err + } + return false, err + } + + exists, err := s.UserExists(ctx, db, user) + if err != nil { + return true, err + } + if !exists { + s.DeleteSecrets(ctx, instance, sqlUserSecretClient) + return false, nil + } + + err = s.DropUser(ctx, db, user) + if err != nil { + instance.Status.Message = fmt.Sprintf("Delete AzureSqlUser failed with %s", err.Error()) + return false, err + } + + // Once the user has been dropped, also delete their secrets. + s.DeleteSecrets(ctx, instance, sqlUserSecretClient) + + instance.Status.Message = fmt.Sprintf("Delete AzureSqlUser succeeded") + + return true, nil +} + +// GetParents gets the parents of the user +func (s *AzureSqlUserManager) GetParents(obj runtime.Object) ([]resourcemanager.KubeParent, error) { + instance, err := s.convert(obj) + if err != nil { + return nil, err + } + + return []resourcemanager.KubeParent{ + { + Key: types.NamespacedName{ + Namespace: instance.Namespace, + Name: instance.Spec.DbName, + }, + Target: &v1alpha1.AzureSqlDatabase{}, + }, + { + Key: types.NamespacedName{ + Namespace: instance.Namespace, + Name: instance.Spec.Server, + }, + Target: &v1alpha1.AzureSqlServer{}, + }, + { + Key: types.NamespacedName{ + Namespace: instance.Namespace, + Name: instance.Spec.ResourceGroup, + }, + Target: &v1alpha1.ResourceGroup{}, + }, + }, nil +} + +// GetStatus gets the status +func (s *AzureSqlUserManager) GetStatus(obj runtime.Object) (*v1alpha1.ASOStatus, error) { + instance, err := s.convert(obj) + if err != nil { + return nil, err + } + return &instance.Status, nil +} + +func (s *AzureSqlUserManager) convert(obj runtime.Object) (*v1alpha1.AzureSQLUser, error) { + local, ok := obj.(*v1alpha1.AzureSQLUser) + if !ok { + return nil, fmt.Errorf("failed type assertion on kind: %s", obj.GetObjectKind().GroupVersionKind().String()) + } + return local, nil +}