Skip to content

Commit

Permalink
Support non-public cloud environments in the Azure Storage Queue and …
Browse files Browse the repository at this point in the history
…Azure Storage Blob scalers (#1863)
  • Loading branch information
amirschw authored Jun 8, 2021
1 parent b68db58 commit fb76638
Show file tree
Hide file tree
Showing 14 changed files with 139 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

### New

- Support non-public cloud environments in the Azure Storage Queue and Azure Storage Blob scalers ([#1863](https://github.com/kedacore/keda/pull/1863))
- Show HashiCorp Vault Address when using `kubectl get ta` or `kubectl get cta` ([#1862](https://github.com/kedacore/keda/pull/1862))

### Improvements
Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/azure/azure_blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
)

// GetAzureBlobListLength returns the count of the blobs in blob container in int
func GetAzureBlobListLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string) (int, error) {
credential, endpoint, err := ParseAzureStorageBlobConnection(httpClient, podIdentity, connectionString, accountName)
func GetAzureBlobListLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, blobContainerName string, accountName string, blobDelimiter string, blobPrefix string, endpointSuffix string) (int, error) {
credential, endpoint, err := ParseAzureStorageBlobConnection(httpClient, podIdentity, connectionString, accountName, endpointSuffix)
if err != nil {
return -1, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/azure/azure_blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func TestGetBlobLength(t *testing.T) {
httpClient := http.DefaultClient
length, err := GetAzureBlobListLength(context.TODO(), httpClient, "", "", "blobContainerName", "", "", "")
length, err := GetAzureBlobListLength(context.TODO(), httpClient, "", "", "blobContainerName", "", "", "", "")
if length != -1 {
t.Error("Expected length to be -1, but got", length)
}
Expand All @@ -22,7 +22,7 @@ func TestGetBlobLength(t *testing.T) {
t.Error("Expected error to contain parsing error message, but got", err.Error())
}

length, err = GetAzureBlobListLength(context.TODO(), httpClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "")
length, err = GetAzureBlobListLength(context.TODO(), httpClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "blobContainerName", "", "", "", "")

if length != -1 {
t.Error("Expected length to be -1, but got", length)
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_eventhub_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (checkpointer *defaultCheckpointer) extractCheckpoint(get *azblob.DownloadR
}

func getCheckpoint(ctx context.Context, httpClient util.HTTPDoer, info EventHubInfo, checkpointer checkpointer) (Checkpoint, error) {
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "")
blobCreds, storageEndpoint, err := ParseAzureStorageBlobConnection(httpClient, kedav1alpha1.PodIdentityProviderNone, info.StorageConnection, "", "")
if err != nil {
return Checkpoint{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure/azure_eventhub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func TestShouldParseCheckpointForGoSdk(t *testing.T) {
}

func createNewCheckpointInStorage(urlPath string, containerName string, partitionID string, checkpoint string, metadata map[string]string) (context.Context, error) {
credential, endpoint, _ := ParseAzureStorageBlobConnection(http.DefaultClient, "none", StorageConnectionString, "")
credential, endpoint, _ := ParseAzureStorageBlobConnection(http.DefaultClient, "none", StorageConnectionString, "", "")

// Create container
ctx := context.Background()
Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/azure/azure_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
)

// GetAzureQueueLength returns the length of a queue in int
func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName string, accountName string) (int32, error) {
credential, endpoint, err := ParseAzureStorageQueueConnection(httpClient, podIdentity, connectionString, accountName)
func GetAzureQueueLength(ctx context.Context, httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, queueName, accountName, endpointSuffix string) (int32, error) {
credential, endpoint, err := ParseAzureStorageQueueConnection(httpClient, podIdentity, connectionString, accountName, endpointSuffix)
if err != nil {
return -1, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/scalers/azure/azure_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

func TestGetQueueLength(t *testing.T) {
length, err := GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "", "queueName", "")
length, err := GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "", "queueName", "", "")
if length != -1 {
t.Error("Expected length to be -1, but got", length)
}
Expand All @@ -21,7 +21,7 @@ func TestGetQueueLength(t *testing.T) {
t.Error("Expected error to contain parsing error message, but got", err.Error())
}

length, err = GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "")
length, err = GetAzureQueueLength(context.TODO(), http.DefaultClient, "", "DefaultEndpointsProtocol=https;AccountName=name;AccountKey=key==;EndpointSuffix=core.windows.net", "queueName", "", "")

if length != -1 {
t.Error("Expected length to be -1, but got", length)
Expand Down
67 changes: 51 additions & 16 deletions pkg/scalers/azure/azure_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/azure-storage-queue-go/azqueue"
az "github.com/Azure/go-autorest/autorest/azure"

kedav1alpha1 "github.com/kedacore/keda/v2/api/v1alpha1"
"github.com/kedacore/keda/v2/pkg/util"
Expand All @@ -30,6 +31,9 @@ const (
TableEndpoint
// FileEndpoint storage type
FileEndpoint

// PrivateCloud cloud type
PrivateCloud string = "Private"
)

// Prefix returns prefix for a StorageEndpointType
Expand All @@ -42,21 +46,42 @@ func (e StorageEndpointType) Name() string {
return [...]string{"blob", "queue", "table", "file"}[e]
}

// GetEndpointSuffix returns the endpoint suffix for a StorageEndpointType based on the specified environment
func (e StorageEndpointType) GetEndpointSuffix(environment az.Environment) string {
return fmt.Sprintf("%s.%s", e.Name(), environment.StorageEndpointSuffix)
}

// ParseAzureStorageEndpointSuffix parses cloud and endpointSuffix metadata and returns endpoint suffix
func ParseAzureStorageEndpointSuffix(metadata map[string]string, endpointType StorageEndpointType) (string, error) {
if val, ok := metadata["cloud"]; ok && val != "" {
if strings.EqualFold(val, PrivateCloud) {
if val, ok := metadata["endpointSuffix"]; ok && val != "" {
return val, nil
}
return "", fmt.Errorf("endpointSuffix must be provided for %s cloud type", PrivateCloud)
}

env, err := az.EnvironmentFromName(val)
if err != nil {
return "", fmt.Errorf("invalid cloud environment %s", val)
}
return endpointType.GetEndpointSuffix(env), nil
}

// Use the default public cloud endpoint suffix if `cloud` isn't specified
return endpointType.GetEndpointSuffix(az.PublicCloud), nil
}

// ParseAzureStorageQueueConnection parses queue connection string and returns credential and resource url
func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azqueue.Credential, *url.URL, error) {
func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azqueue.Credential, *url.URL, error) {
switch podIdentity {
case kedav1alpha1.PodIdentityProviderAzure:
token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/")
token, endpoint, err := parseAcessTokenAndEndpoint(httpClient, accountName, endpointSuffix)
if err != nil {
return nil, nil, err
}

if accountName == "" {
return nil, nil, fmt.Errorf("accountName is required for podIdentity azure")
}

credential := azqueue.NewTokenCredential(token.AccessToken, nil)
endpoint, _ := url.Parse(fmt.Sprintf("https://%s.queue.core.windows.net", accountName))
credential := azqueue.NewTokenCredential(token, nil)
return credential, endpoint, nil
case "", kedav1alpha1.PodIdentityProviderNone:
endpoint, accountName, accountKey, err := parseAzureStorageConnectionString(connectionString, QueueEndpoint)
Expand All @@ -76,20 +101,15 @@ func ParseAzureStorageQueueConnection(httpClient util.HTTPDoer, podIdentity keda
}

// ParseAzureStorageBlobConnection parses blob connection string and returns credential and resource url
func ParseAzureStorageBlobConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName string) (azblob.Credential, *url.URL, error) {
func ParseAzureStorageBlobConnection(httpClient util.HTTPDoer, podIdentity kedav1alpha1.PodIdentityProvider, connectionString, accountName, endpointSuffix string) (azblob.Credential, *url.URL, error) {
switch podIdentity {
case kedav1alpha1.PodIdentityProviderAzure:
token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/")
token, endpoint, err := parseAcessTokenAndEndpoint(httpClient, accountName, endpointSuffix)
if err != nil {
return nil, nil, err
}

if accountName == "" {
return nil, nil, fmt.Errorf("accountName is required for podIdentity azure")
}

credential := azblob.NewTokenCredential(token.AccessToken, nil)
endpoint, _ := url.Parse(fmt.Sprintf("https://%s.blob.core.windows.net", accountName))
credential := azblob.NewTokenCredential(token, nil)
return credential, endpoint, nil
case "", kedav1alpha1.PodIdentityProviderNone:
endpoint, accountName, accountKey, err := parseAzureStorageConnectionString(connectionString, BlobEndpoint)
Expand Down Expand Up @@ -164,3 +184,18 @@ func parseAzureStorageConnectionString(connectionString string, endpointType Sto

return u, name, key, nil
}

func parseAcessTokenAndEndpoint(httpClient util.HTTPDoer, accountName string, endpointSuffix string) (string, *url.URL, error) {
// Azure storage resource is "https://storage.azure.com/" in all cloud environments
token, err := GetAzureADPodIdentityToken(httpClient, "https://storage.azure.com/")
if err != nil {
return "", nil, err
}

if accountName == "" {
return "", nil, fmt.Errorf("accountName is required for podIdentity azure")
}

endpoint, _ := url.Parse(fmt.Sprintf("https://%s.%s", accountName, endpointSuffix))
return token.AccessToken, endpoint, nil
}
36 changes: 36 additions & 0 deletions pkg/scalers/azure/azure_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,39 @@ func TestParseStorageConnectionString(t *testing.T) {
}
}
}

type parseEndpointSuffixTestData struct {
metadata map[string]string
endpointSuffix string
endpointType StorageEndpointType
isError bool
}

var parseEndpointSuffixTestDataset = []parseEndpointSuffixTestData{
{map[string]string{}, "queue.core.windows.net", QueueEndpoint, false},
{map[string]string{"cloud": "InvalidCloud"}, "", QueueEndpoint, true},
{map[string]string{"cloud": "AzureUSGovernmentCloud"}, "queue.core.usgovcloudapi.net", QueueEndpoint, false},
{map[string]string{"cloud": "Private"}, "", BlobEndpoint, true},
{map[string]string{"cloud": "Private", "endpointSuffix": "blob.core.private.cloud"}, "blob.core.private.cloud", BlobEndpoint, false},
{map[string]string{"endpointSuffix": "ignored"}, "blob.core.windows.net", BlobEndpoint, false},
}

func TestParseAzureStorageEndpointSuffix(t *testing.T) {
for _, testData := range parseEndpointSuffixTestDataset {
endpointSuffix, err := ParseAzureStorageEndpointSuffix(testData.metadata, testData.endpointType)
if !testData.isError && err != nil {
t.Error("Expected success but got error", err)
}
if testData.isError && err == nil {
t.Error("Expected error but got success")
}
if err == nil {
if endpointSuffix != testData.endpointSuffix {
t.Error(
"For", testData.metadata,
"expected endpointSuffix=", testData.endpointSuffix,
"but got", endpointSuffix)
}
}
}
}
10 changes: 10 additions & 0 deletions pkg/scalers/azure_blob_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type azureBlobMetadata struct {
connection string
accountName string
metricName string
endpointSuffix string
}

var azureBlobLog = logf.Log.WithName("azure_blob_scaler")
Expand Down Expand Up @@ -88,6 +89,13 @@ func parseAzureBlobMetadata(config *ScalerConfig) (*azureBlobMetadata, kedav1alp
meta.blobPrefix = val + meta.blobDelimiter
}

endpointSuffix, err := azure.ParseAzureStorageEndpointSuffix(config.TriggerMetadata, azure.BlobEndpoint)
if err != nil {
return nil, "", err
}

meta.endpointSuffix = endpointSuffix

// before triggerAuthentication CRD, pod identity was configured using this property
if val, ok := config.TriggerMetadata["useAAdPodIdentity"]; ok && config.PodIdentity == "" && val == "true" {
config.PodIdentity = kedav1alpha1.PodIdentityProviderAzure
Expand Down Expand Up @@ -143,6 +151,7 @@ func (s *azureBlobScaler) IsActive(ctx context.Context) (bool, error) {
s.metadata.accountName,
s.metadata.blobDelimiter,
s.metadata.blobPrefix,
s.metadata.endpointSuffix,
)

if err != nil {
Expand Down Expand Up @@ -183,6 +192,7 @@ func (s *azureBlobScaler) GetMetrics(ctx context.Context, metricName string, met
s.metadata.accountName,
s.metadata.blobDelimiter,
s.metadata.blobPrefix,
s.metadata.endpointSuffix,
)

if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions pkg/scalers/azure_blob_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ var testAzBlobMetadata = []parseAzBlobMetadataTestData{
{map[string]string{"accountName": "", "blobContainerName": "sample_container"}, true, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure without blob container name
{map[string]string{"accountName": "sample_acc", "blobContainerName": ""}, true, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with cloud
{map[string]string{"accountName": "sample_acc", "blobContainerName": "sample_container", "cloud": "AzureGermanCloud"}, false, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with invalid cloud
{map[string]string{"accountName": "sample_acc", "blobContainerName": "sample_container", "cloud": "InvalidCloud"}, true, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with private cloud and endpoint suffix
{map[string]string{"accountName": "sample_acc", "blobContainerName": "sample_container", "cloud": "Private", "endpointSuffix": "queue.core.private.cloud"}, false, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with private cloud and no endpoint suffix
{map[string]string{"accountName": "sample_acc", "blobContainerName": "sample_container", "cloud": "Private", "endpointSuffix": ""}, true, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with endpoint suffix and no cloud
{map[string]string{"accountName": "sample_acc", "blobContainerName": "sample_container", "cloud": "", "endpointSuffix": "ignored"}, false, testAzBlobResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// connection from authParams
{map[string]string{"blobContainerName": "sample_container", "blobCount": "5"}, false, testAzBlobResolvedEnv, map[string]string{"connection": "value"}, kedav1alpha1.PodIdentityProviderNone},
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/scalers/azure_eventhub_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestGetUnprocessedEventCountInPartition(t *testing.T) {

if eventHubKey != "" && storageConnectionString != "" {
eventHubConnectionString := fmt.Sprintf("Endpoint=sb://%s.servicebus.windows.net/;SharedAccessKeyName=RootManageSharedAccessKey;SharedAccessKey=%s;EntityPath=%s", testEventHubNamespace, eventHubKey, testEventHubName)
storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(http.DefaultClient, "none", storageConnectionString, "")
storageCredentials, endpoint, err := azure.ParseAzureStorageBlobConnection(http.DefaultClient, "none", storageConnectionString, "", "")
if err != nil {
t.Error(err)
t.FailNow()
Expand Down
10 changes: 10 additions & 0 deletions pkg/scalers/azure_queue_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type azureQueueMetadata struct {
queueName string
connection string
accountName string
endpointSuffix string
}

var azureQueueLog = logf.Log.WithName("azure_queue_scaler")
Expand Down Expand Up @@ -68,6 +69,13 @@ func parseAzureQueueMetadata(config *ScalerConfig) (*azureQueueMetadata, kedav1a
meta.targetQueueLength = queueLength
}

endpointSuffix, err := azure.ParseAzureStorageEndpointSuffix(config.TriggerMetadata, azure.QueueEndpoint)
if err != nil {
return nil, "", err
}

meta.endpointSuffix = endpointSuffix

if val, ok := config.TriggerMetadata["queueName"]; ok && val != "" {
meta.queueName = val
} else {
Expand Down Expand Up @@ -120,6 +128,7 @@ func (s *azureQueueScaler) IsActive(ctx context.Context) (bool, error) {
s.metadata.connection,
s.metadata.queueName,
s.metadata.accountName,
s.metadata.endpointSuffix,
)

if err != nil {
Expand Down Expand Up @@ -158,6 +167,7 @@ func (s *azureQueueScaler) GetMetrics(ctx context.Context, metricName string, me
s.metadata.connection,
s.metadata.queueName,
s.metadata.accountName,
s.metadata.endpointSuffix,
)

if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions pkg/scalers/azure_queue_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ var testAzQueueMetadata = []parseAzQueueMetadataTestData{
{map[string]string{"accountName": "", "queueName": "sample_queue"}, true, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure without queue name
{map[string]string{"accountName": "sample_acc", "queueName": ""}, true, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with cloud
{map[string]string{"accountName": "sample_acc", "queueName": "sample_queue", "cloud": "AzurePublicCloud"}, false, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with invalid cloud
{map[string]string{"accountName": "sample_acc", "queueName": "sample_queue", "cloud": "InvalidCloud"}, true, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with private cloud and endpoint suffix
{map[string]string{"accountName": "sample_acc", "queueName": "sample_queue", "cloud": "Private", "endpointSuffix": "queue.core.private.cloud"}, false, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with private cloud and no endpoint suffix
{map[string]string{"accountName": "sample_acc", "queueName": "sample_queue", "cloud": "Private", "endpointSuffix": ""}, true, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// podIdentity = azure with endpoint suffix and no cloud
{map[string]string{"accountName": "sample_acc", "queueName": "sample_queue", "cloud": "", "endpointSuffix": "ignored"}, false, testAzQueueResolvedEnv, map[string]string{}, kedav1alpha1.PodIdentityProviderAzure},
// connection from authParams
{map[string]string{"queueName": "sample", "queueLength": "5"}, false, testAzQueueResolvedEnv, map[string]string{"connection": "value"}, kedav1alpha1.PodIdentityProviderNone},
}
Expand Down

0 comments on commit fb76638

Please sign in to comment.