Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: support azure MangedIdentity TokenCredential #590

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="Azure.Storage.Blobs" Version="12.14.1" />
<PackageVersion Include="Azure.Identity" Version="1.12.0" />
<PackageVersion Include="BenchmarkDotNet" Version="0.13.12" />
<PackageVersion Include="BenchmarkDotNet.Diagnostics.Windows" Version="0.13.12" />
<PackageVersion Include="CommandLineParser" Version="2.9.1" />
Expand Down
30 changes: 27 additions & 3 deletions libs/host/Configuration/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net;
using System.Reflection;
using System.Security.Cryptography.X509Certificates;
using Azure.Identity;
using CommandLine;
using Garnet.server;
using Garnet.server.Auth.Aad;
Expand Down Expand Up @@ -295,6 +296,13 @@ internal sealed class Options
[Option("use-azure-storage", Required = false, HelpText = "Use Azure Page Blobs for storage instead of local storage.")]
public bool? UseAzureStorage { get; set; }

[HttpsUrlValidation]
[Option("storage-service-uri", Required = false, HelpText = "The URI to use when establishing connection to Azure Blobs Storage.")]
public string AzureStorageServiceUri { get; set; }

[Option("storage-managed-identity", Required = false, HelpText = "The managed identity to use when establishing connection to Azure Blobs Storage.")]
public string AzureStorageManagedIdentity { get; set; }

[Option("storage-string", Required = false, HelpText = "The connection string to use when establishing connection to Azure Blobs Storage.")]
public string AzureStorageConnectionString { get; set; }

Expand Down Expand Up @@ -487,8 +495,17 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null)
var enableStorageTier = EnableStorageTier.GetValueOrDefault();
var enableRevivification = EnableRevivification.GetValueOrDefault();

if (useAzureStorage && string.IsNullOrEmpty(AzureStorageConnectionString))
throw new Exception("Cannot enable use-azure-storage without supplying storage-string.");
if (useAzureStorage && (
string.IsNullOrEmpty(AzureStorageConnectionString)
&& (string.IsNullOrEmpty(AzureStorageServiceUri) || string.IsNullOrEmpty(AzureStorageManagedIdentity))))
{
throw new InvalidAzureConfiguration("Cannot enable use-azure-storage without supplying storage-string or storage-service-uri & storage-managed-identity");
}
if (useAzureStorage && !string.IsNullOrEmpty(AzureStorageConnectionString)
&& (!string.IsNullOrEmpty(AzureStorageServiceUri) || !string.IsNullOrEmpty(AzureStorageManagedIdentity)))
{
throw new InvalidAzureConfiguration("Cannot enable use-azure-storage with both storage-string and storage-service-uri or storage-managed-identity");
}

var logDir = LogDir;
if (!useAzureStorage && enableStorageTier) logDir = new DirectoryInfo(string.IsNullOrEmpty(logDir) ? "." : logDir).FullName;
Expand Down Expand Up @@ -606,7 +623,9 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null)
ThreadPoolMinThreads = ThreadPoolMinThreads,
ThreadPoolMaxThreads = ThreadPoolMaxThreads,
DeviceFactoryCreator = useAzureStorage
? () => new AzureStorageNamedDeviceFactory(AzureStorageConnectionString, logger)
? string.IsNullOrEmpty(AzureStorageConnectionString)
? () => new AzureStorageNamedDeviceFactory(AzureStorageServiceUri, new ManagedIdentityCredential(AzureStorageManagedIdentity), logger)
: () => new AzureStorageNamedDeviceFactory(AzureStorageConnectionString, logger)
: () => new LocalStorageNamedDeviceFactory(useNativeDeviceLinux: UseNativeDeviceLinux.GetValueOrDefault(), logger: logger),
CheckpointThrottleFlushDelayMs = CheckpointThrottleFlushDelayMs,
EnableScatterGatherGet = EnableScatterGatherGet.GetValueOrDefault(),
Expand Down Expand Up @@ -665,4 +684,9 @@ internal enum ConfigFileType
// Redis.conf file format
RedisConf = 1,
}

public class InvalidAzureConfiguration : Exception
{
public InvalidAzureConfiguration(string message) : base(message) { }
}
}
42 changes: 35 additions & 7 deletions libs/host/Configuration/OptionsValidators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ internal LogDirValidationAttribute(bool mustExist, bool isRequired) : base(mustE
}

/// <summary>
/// Validation logic for Log Directory, valid if UseAzureStorage is specified or if EnableStorageTier is not specified in parent Options object
/// If neither applies, reverts to OptionValidationAttribute validation
/// Validation logic for Log Directory, valid if <see cref="Options.UseAzureStorage"/> is specified or if <see cref="Options.EnableStorageTier"/> is not specified in parent Options object
/// If neither applies, reverts to <see cref="OptionValidationAttribute"/> validation
/// </summary>
/// <param name="value">Value of Log Directory</param>
/// <param name="validationContext">Validation context</param>
Expand All @@ -483,13 +483,12 @@ internal CheckpointDirValidationAttribute(bool mustExist, bool isRequired) : bas
}

/// <summary>
/// Validation logic for Checkpoint Directory, valid if UseAzureStorage is specified in parent Options object
/// If not, reverts to OptionValidationAttribute validation
/// Validation logic for <see cref="Options.CheckpointDir"/>, valid if <see cref="Options.UseAzureStorage"/> is specified in parent Options object
/// If not, reverts to <see cref="OptionValidationAttribute"/> validation
/// </summary>
/// <param name="value">Value of Log Directory</param>
/// <param name="validationContext">Validation context</param>
/// <returns>Validation result</returns>
/// <returns></returns>
protected override ValidationResult IsValid(object value, ValidationContext validationContext)
{
var options = (Options)validationContext.ObjectInstance;
Expand All @@ -501,7 +500,7 @@ protected override ValidationResult IsValid(object value, ValidationContext vali
}

/// <summary>
/// Validation logic for CertFileName
/// Validation logic for <see cref="Options.CertFileName"/>
/// </summary>
[AttributeUsage(AttributeTargets.Property)]
internal sealed class CertFileValidationAttribute : FilePathValidationAttribute
Expand All @@ -518,7 +517,6 @@ internal CertFileValidationAttribute(bool fileMustExist, bool directoryMustExist
/// <param name="value">Value of CertFileName</param>
/// <param name="validationContext">Validation context</param>
/// <returns>Validation result</returns>
/// <returns></returns>
protected override ValidationResult IsValid(object value, ValidationContext validationContext)
{
var options = (Options)validationContext.ObjectInstance;
Expand All @@ -528,4 +526,34 @@ protected override ValidationResult IsValid(object value, ValidationContext vali
return base.IsValid(value, validationContext);
}
}

/// <summary>
/// Represents an attribute used for validating HTTPS URLs as options.
/// </summary>
[AttributeUsage(AttributeTargets.Property)]
internal sealed class HttpsUrlValidationAttribute : OptionValidationAttribute
{
internal HttpsUrlValidationAttribute(bool isRequired = false) : base(isRequired)
{
}

/// <summary>
/// HTTPS URLs validation logic, checks if string is a valid HTTPS URL.
/// </summary>
/// <param name="value">URL string</param>
/// <param name="validationContext">Validation Logic</param>
/// <returns>Validation result</returns>
protected override ValidationResult IsValid(object value, ValidationContext validationContext)
{
if (TryInitialValidation<string>(value, validationContext, out var initValidationResult, out var url))
return initValidationResult;

if (Uri.TryCreate(url, UriKind.Absolute, out var uri) && uri.Scheme == Uri.UriSchemeHttps)
return ValidationResult.Success;

var baseError = validationContext.MemberName != null ? base.FormatErrorMessage(validationContext.MemberName) : string.Empty;
var errorMessage = $"{baseError} Expected string in URI format. Actual value: {url}";
return new ValidationResult(errorMessage, [validationContext.MemberName]);
}
}
}
1 change: 1 addition & 0 deletions libs/host/Garnet.host.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

<ItemGroup>
<PackageReference Include="CommandLineParser" />
<PackageReference Include="Azure.Identity" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" />
<PackageReference Include="Microsoft.SourceLink.GitHub" PrivateAssets="All" />
<PackageReference Include="Microsoft.Extensions.Logging" />
Expand Down
6 changes: 6 additions & 0 deletions libs/host/defaults.conf
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@
/* The connection string to use when establishing connection to Azure Blobs Storage. */
"AzureStorageConnectionString" : null,

/* The URI to use when establishing connection to Azure Blobs Storage. */
"AzureStorageServiceUri": null,

/* The managed identity to use when establishing connection to Azure Blobs Storage. */
"AzureStorageManagedIdentity": null,

/* Whether and by how much should we throttle the disk IO for checkpoints: -1 - disable throttling; >= 0 - run checkpoint flush in separate task, sleep for specified time after each WriteAsync */
"CheckpointThrottleFlushDelayMs" : 0,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Tsavorite.core;

Expand All @@ -29,6 +30,17 @@ public AzureStorageNamedDeviceFactory(string connectionString, ILogger logger =
{
}

/// <summary>
/// Create instance of factory for Azure devices
/// </summary>
/// <param name="serviceUri"></param>
/// <param name="credential"></param>
/// <param name="logger"></param>
public AzureStorageNamedDeviceFactory(string serviceUri, TokenCredential credential, ILogger logger = null)
Meir017 marked this conversation as resolved.
Show resolved Hide resolved
: this(BlobUtilsV12.GetServiceClients(serviceUri, credential), logger)
{
}

/// <summary>
/// Create instance of factory for Azure devices
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ public struct ServiceClients
}

internal static ServiceClients GetServiceClients(string connectionString)
{
var (aggressiveOptions, defaultOptions, withRetriesOptions) = GetBlobClientOptions();

return new ServiceClients()
{
Default = new BlobServiceClient(connectionString, defaultOptions),
Aggressive = new BlobServiceClient(connectionString, aggressiveOptions),
WithRetries = new BlobServiceClient(connectionString, withRetriesOptions),
};
}

internal static ServiceClients GetServiceClients(string serviceUrl, TokenCredential credential)
{
var (aggressiveOptions, defaultOptions, withRetriesOptions) = GetBlobClientOptions();
var serviceUri = new Uri(serviceUrl);

return new ServiceClients()
{
Default = new BlobServiceClient(serviceUri, credential, defaultOptions),
Aggressive = new BlobServiceClient(serviceUri, credential, aggressiveOptions),
WithRetries = new BlobServiceClient(serviceUri, credential, withRetriesOptions),
};
}

private static (BlobClientOptions aggressiveOptions, BlobClientOptions defaultOptions, BlobClientOptions withRetriesOptions) GetBlobClientOptions()
{
var aggressiveOptions = new BlobClientOptions();
aggressiveOptions.Retry.MaxRetries = 0;
Expand All @@ -54,12 +79,7 @@ internal static ServiceClients GetServiceClients(string connectionString)
withRetriesOptions.Retry.Delay = TimeSpan.FromSeconds(1);
withRetriesOptions.Retry.MaxDelay = TimeSpan.FromSeconds(30);

return new ServiceClients()
{
Default = new BlobServiceClient(connectionString, defaultOptions),
Aggressive = new BlobServiceClient(connectionString, aggressiveOptions),
WithRetries = new BlobServiceClient(connectionString, withRetriesOptions),
};
return (aggressiveOptions, defaultOptions, withRetriesOptions);
}

public struct ContainerClients
Expand Down
Loading
Loading