Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DisableServerCertificateValidation: Fixes Default HttpClient to honor DisableServerCertificateValidation #4294

53 changes: 37 additions & 16 deletions Microsoft.Azure.Cosmos/src/CosmosClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -664,11 +664,46 @@ internal Protocol ConnectionProtocol
/// <para>
/// Emulator: To ignore SSL Certificate please suffix connectionstring with "DisableServerCertificateValidation=True;".
/// When CosmosClientOptions.HttpClientFactory is used, SSL certificate needs to be handled appropriately.
/// NOTE: DO NOT use this flag in production (only for emulator)
/// NOTE: DO NOT use the `DisableServerCertificateValidation` flag in production (only for emulator)
/// </para>
/// </remarks>
public Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> ServerCertificateCustomValidationCallback { get; set; }


/// <summary>
/// Real call back that will be hooked down-stream to the transport clients (both http and tcp).
/// NOTE: All down stream real-usage should come through this API only and not through the public API.
///
/// Test hook DisableServerCertificateValidationInvocationCallback
/// - When configured will invoke it when ever custom validation is done
/// </summary>
internal Func<X509Certificate2, X509Chain, SslPolicyErrors, bool> GetServerCertificateCustomValidationCallback()
{
if (this.DisableServerCertificateValidation)
{
if (this.DisableServerCertificateValidationInvocationCallback == null)
{
return this.ServerCertificateCustomValidationCallback ?? ((_, _, _) => true);
}
else
{
return (X509Certificate2 cert, X509Chain chain, SslPolicyErrors policyErrors) =>
{
bool bValidationResult = true;
if (this.ServerCertificateCustomValidationCallback != null)
{
bValidationResult = this.ServerCertificateCustomValidationCallback(cert, chain, policyErrors);
}
this.DisableServerCertificateValidationInvocationCallback?.Invoke();
return bValidationResult;
};
}
}

return this.ServerCertificateCustomValidationCallback;
}

internal Action DisableServerCertificateValidationInvocationCallback { get; set; }

/// <summary>
/// API type for the account
/// </summary>
Expand Down Expand Up @@ -773,7 +808,6 @@ internal virtual ConnectionPolicy GetConnectionPolicy(int clientId)
this.ValidateDirectTCPSettings();
this.ValidateLimitToEndpointSettings();
this.ValidatePartitionLevelFailoverSettings();
this.ValidateAndSetServerCallbackSettings();

ConnectionPolicy connectionPolicy = new ConnectionPolicy()
{
Expand Down Expand Up @@ -947,19 +981,6 @@ private void ValidatePartitionLevelFailoverSettings()
}
}

private void ValidateAndSetServerCallbackSettings()
{
if (this.DisableServerCertificateValidation && this.ServerCertificateCustomValidationCallback != null)
{
throw new ArgumentException($"Cannot specify {nameof(this.DisableServerCertificateValidation)} flag in Connection String and {nameof(this.ServerCertificateCustomValidationCallback)}. Only one can be set.");
}

if (this.DisableServerCertificateValidation)
{
this.ServerCertificateCustomValidationCallback = (_, _, _) => true;
}
}

private void ValidateDirectTCPSettings()
{
string settingName = string.Empty;
Expand Down
4 changes: 2 additions & 2 deletions Microsoft.Azure.Cosmos/src/Resource/ClientContextCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ internal static CosmosClientContext Create(
HttpMessageHandler httpMessageHandler = CosmosHttpClientCore.CreateHttpClientHandler(
clientOptions.GatewayModeMaxConnectionLimit,
clientOptions.WebProxy,
clientOptions.ServerCertificateCustomValidationCallback);
clientOptions.GetServerCertificateCustomValidationCallback());

DocumentClient documentClient = new DocumentClient(
cosmosClient.Endpoint,
Expand All @@ -81,7 +81,7 @@ internal static CosmosClientContext Create(
handler: httpMessageHandler,
sessionContainer: clientOptions.SessionContainer,
cosmosClientId: cosmosClient.Id,
remoteCertificateValidationCallback: ClientContextCore.SslCustomValidationCallBack(clientOptions.ServerCertificateCustomValidationCallback),
remoteCertificateValidationCallback: ClientContextCore.SslCustomValidationCallBack(clientOptions.GetServerCertificateCustomValidationCallback()),
cosmosClientTelemetryOptions: clientOptions.CosmosClientTelemetryOptions,
chaosInterceptorFactory: clientOptions.ChaosInterceptorFactory);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,44 @@ public async Task Verify_CertificateCallBackGetsCalled_ForTCP_HTTP()
}
}

[TestMethod]
public async Task Verify_DisableCertificateValidationCallBackGetsCalled_ForTCP_HTTP()
{
int counter = 0;
CosmosClientOptions options = new CosmosClientOptions()
{
DisableServerCertificateValidationInvocationCallback = () => counter++,
};

string authKey = ConfigurationManager.AppSettings["MasterKey"];
string endpoint = ConfigurationManager.AppSettings["GatewayEndpoint"];
string connectionStringWithSslDisable = $"AccountEndpoint={endpoint};AccountKey={authKey};DisableServerCertificateValidation=true";

using CosmosClient cosmosClient = new CosmosClient(connectionStringWithSslDisable, options);

string databaseName = Guid.NewGuid().ToString();
string databaseId = Guid.NewGuid().ToString();
Cosmos.Database database = null;

try
{
//HTTP callback
Trace.TraceInformation("Creating test database and container");
database = await cosmosClient.CreateDatabaseAsync(databaseId);
Cosmos.Container container = await database.CreateContainerAsync(Guid.NewGuid().ToString(), "/id");

// TCP callback
ToDoActivity item = ToDoActivity.CreateRandomToDoActivity();
ResponseMessage responseMessage = await container.CreateItemStreamAsync(TestCommon.SerializerCore.ToStream(item), new Cosmos.PartitionKey(item.id));
}
finally
{
await database?.DeleteStreamAsync();
}

Assert.IsTrue(counter >= 2);
}

[TestMethod]
public void SqlQuerySpecSerializationTest()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,27 +938,56 @@ public void TestServerCertificatesValidationCallback(string connStr, bool expect

if (expectedIgnoreCertificateFlag)
{
Assert.IsNotNull(cosmosClient.ClientOptions.ServerCertificateCustomValidationCallback);
Assert.IsNull(cosmosClient.ClientOptions.ServerCertificateCustomValidationCallback);
Assert.IsNull(cosmosClient.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback);
Assert.IsTrue(cosmosClient.ClientOptions.DisableServerCertificateValidation);
Assert.IsTrue(cosmosClient
.ClientOptions
.ServerCertificateCustomValidationCallback(x509Certificate2, x509Chain, sslPolicyErrors));
.GetServerCertificateCustomValidationCallback()(x509Certificate2, x509Chain, sslPolicyErrors));


CosmosHttpClient httpClient = cosmosClient.DocumentClient.httpClient;
SocketsHttpHandler socketsHttpHandler = (SocketsHttpHandler)httpClient.HttpMessageHandler;

RemoteCertificateValidationCallback httpClientRemoreCertValidationCallback = socketsHttpHandler.SslOptions.RemoteCertificateValidationCallback;
Assert.IsNotNull(httpClientRemoreCertValidationCallback);

Assert.IsTrue(httpClientRemoreCertValidationCallback(this, x509Certificate2, x509Chain, sslPolicyErrors));
}
else
{
Assert.IsNull(cosmosClient.ClientOptions.ServerCertificateCustomValidationCallback);
Assert.IsFalse(cosmosClient.ClientOptions.DisableServerCertificateValidation);

Assert.IsNull(cosmosClient.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback);
}
}

[TestMethod]
[DataRow(ConnectionString + "DisableServerCertificateValidation=true;")]
[ExpectedException(typeof(ArgumentException))]
public void TestServerCertificatesValidationWithDisableSSLFlagTrue(string connStr)
[DataRow(ConnectionString + "DisableServerCertificateValidation=true;", true)]
[DataRow(ConnectionString + "DisableServerCertificateValidation=true;", false)]
public void TestServerCertificatesValidationWithDisableSSLFlagTrue(string connStr, bool setCallback)
{
CosmosClientOptions options = new CosmosClientOptions
{
ServerCertificateCustomValidationCallback = (certificate, chain, sslPolicyErrors) => true
ServerCertificateCustomValidationCallback = (certificate, chain, sslPolicyErrors) => true,
};

if (setCallback)
{
options.DisableServerCertificateValidationInvocationCallback = () => { };
}

CosmosClient cosmosClient = new CosmosClient(connStr, options);
Assert.IsTrue(cosmosClient.ClientOptions.DisableServerCertificateValidation);
Assert.AreEqual(cosmosClient.ClientOptions.ServerCertificateCustomValidationCallback, options.ServerCertificateCustomValidationCallback);
Assert.AreEqual(cosmosClient.DocumentClient.ConnectionPolicy.ServerCertificateCustomValidationCallback, options.ServerCertificateCustomValidationCallback);

CosmosHttpClient httpClient = cosmosClient.DocumentClient.httpClient;
SocketsHttpHandler socketsHttpHandler = (SocketsHttpHandler)httpClient.HttpMessageHandler;

RemoteCertificateValidationCallback? httpClientRemoreCertValidationCallback = socketsHttpHandler.SslOptions.RemoteCertificateValidationCallback;
Assert.IsNotNull(httpClientRemoreCertValidationCallback);
}

private class TestWebProxy : IWebProxy
Expand Down
Loading