Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Adds LINQ Support for FirstOrDefault #4286

Merged
merged 10 commits into from
Jan 30, 2024
28 changes: 28 additions & 0 deletions Microsoft.Azure.Cosmos/src/Linq/ClientOperation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
//------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Linq
{
/// <summary>
/// Represents the operation that needs to be performed on the client side.
/// </summary>
/// <remarks>
/// At the moment, this enum only represents non-aggregate scalar operations such as FirstOrDefault.
/// Furthermore, scalar operations are disallowed in sub-expressions/sub-queries.
/// With these restrictations, enum is sufficient, but in future for a larger surface area we may need
/// to use an object model like ClientQL to represent these operations better.
/// </remarks>
internal enum ClientOperation
adityasa marked this conversation as resolved.
Show resolved Hide resolved
{
/// <summary>
/// Indicates that client does not need to perform any operation on query results.
/// </summary>
None,

/// <summary>
/// Indicates that the client needs to perform FirstOrDefault on the query results returned by the backend.
/// </summary>
FirstOrDefault
}
}
3 changes: 1 addition & 2 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ namespace Microsoft.Azure.Cosmos.Linq
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Diagnostics;
Expand Down Expand Up @@ -774,7 +773,7 @@ public static Task<Response<int>> SumAsync(
return ResponseHelperAsync(source.Sum());
}

return ((CosmosLinqQueryProvider)source.Provider).ExecuteAggregateAsync<int?>(
return cosmosLinqQueryProvider.ExecuteAggregateAsync<int?>(
Expression.Call(
GetMethodInfoOf<IQueryable<int?>, int?>(Queryable.Sum),
source.Expression),
Expand Down
96 changes: 78 additions & 18 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Azure.Cosmos.Linq
using Microsoft.Azure.Cosmos.Serializer;
using Microsoft.Azure.Cosmos.Tracing;
using Newtonsoft.Json;
using Debug = System.Diagnostics.Debug;

/// <summary>
/// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable.
Expand Down Expand Up @@ -108,7 +109,12 @@ public IEnumerator<T> GetEnumerator()
" use GetItemQueryIterator to execute asynchronously");
}

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false);
FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false, out ClientOperation clientOperation);
Debug.Assert(
clientOperation == ClientOperation.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{clientOperation}'");

while (localFeedIterator.HasMoreResults)
{
#pragma warning disable VSTHRD002 // Avoid problematic synchronous waits
Expand All @@ -133,7 +139,7 @@ IEnumerator IEnumerable.GetEnumerator()

public override string ToString()
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions).SqlQuerySpec;
if (querySpec != null)
{
return JsonConvert.SerializeObject(querySpec);
Expand All @@ -144,20 +150,36 @@ public override string ToString()

public QueryDefinition ToQueryDefinition(IDictionary<object, string> parameters = null)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters);
return QueryDefinition.CreateFromQuerySpec(querySpec);
LinqQuery linqQuery = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters);
ClientOperation clientOperation = linqQuery.ClientOperation;
Debug.Assert(
clientOperation == ClientOperation.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{clientOperation}'");

return QueryDefinition.CreateFromQuerySpec(linqQuery.SqlQuerySpec);
}

public FeedIterator<T> ToFeedIterator()
{
return new FeedIteratorInlineCore<T>(this.CreateFeedIterator(true),
this.container.ClientContext);
FeedIterator<T> iterator = this.CreateFeedIterator(true, out ClientOperation clientOperation);
Debug.Assert(
clientOperation == ClientOperation.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{clientOperation}'");

return new FeedIteratorInlineCore<T>(iterator, this.container.ClientContext);
}

public FeedIterator ToStreamIterator()
{
return new FeedIteratorInlineCore(this.CreateStreamIterator(true),
this.container.ClientContext);
FeedIterator iterator = this.CreateStreamIterator(true, out ClientOperation clientOperation);
Debug.Assert(
clientOperation == ClientOperation.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{clientOperation}'");

return new FeedIteratorInlineCore(iterator, this.container.ClientContext);
}

public void Dispose()
Expand All @@ -180,15 +202,18 @@ internal async Task<Response<T>> AggregateResultAsync(CancellationToken cancella
List<T> result = new List<T>();
Headers headers = new Headers();

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false);
FeedIteratorInternal<T> localFeedIteratorInternal = (FeedIteratorInternal<T>)localFeedIterator;
FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, clientOperation: out ClientOperation clientOperation);
Debug.Assert(
clientOperation == ClientOperation.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{clientOperation}'");

ITrace rootTrace;
using (rootTrace = Trace.GetRootTrace("Aggregate LINQ Operation"))
{
while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = await localFeedIteratorInternal.ReadNextAsync(rootTrace, cancellationToken);
FeedResponse<T> response = await localFeedIterator.ReadNextAsync(rootTrace, cancellationToken);
headers.RequestCharge += response.RequestCharge;
result.AddRange(response);
}
Expand All @@ -202,23 +227,58 @@ internal async Task<Response<T>> AggregateResultAsync(CancellationToken cancella
null);
}

private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected)
internal T ExecuteScalar()
{
FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ClientOperation clientOperation);
Headers headers = new Headers();

List<T> result = new List<T>();
ITrace rootTrace;
using (rootTrace = Trace.GetRootTrace("Scalar LINQ Operation"))
{
while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = localFeedIterator.ReadNextAsync(rootTrace, cancellationToken: default).GetAwaiter().GetResult();
headers.RequestCharge += response.RequestCharge;
result.AddRange(response);
}
}

switch (clientOperation)
{
case ClientOperation.FirstOrDefault:
System.Diagnostics.Debug.Assert(result.Count <= 1, "CosmosLinqQuery Assert!", "At most one result is expected!");
return result.FirstOrDefault();

// ExecuteScalar gets called when (sync) aggregates such as Max, Min, Sum are invoked on the IQueryable.
// Since query fully supprots these operations, there is no client operation involved.
// In these cases we return FirstOrDefault which handles empty/undefined/null result set from the backend.
case ClientOperation.None:
return result.FirstOrDefault();
adityasa marked this conversation as resolved.
Show resolved Hide resolved

default:
throw new InvalidOperationException($"Unsupported client operation {clientOperation}");
}
}

private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected, out ClientOperation clientOperation)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
LinqQuery querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
clientOperation = querySpec.ClientOperation;

return this.container.GetItemQueryStreamIteratorInternal(
sqlQuerySpec: querySpec,
sqlQuerySpec: querySpec.SqlQuerySpec,
isContinuationExcpected: isContinuationExcpected,
continuationToken: this.continuationToken,
feedRange: null,
requestOptions: this.cosmosQueryRequestOptions);
}

private FeedIterator<T> CreateFeedIterator(bool isContinuationExpected)
private FeedIteratorInlineCore<T> CreateFeedIterator(bool isContinuationExpected, out ClientOperation clientOperation)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);

FeedIteratorInternal streamIterator = this.CreateStreamIterator(isContinuationExpected);
FeedIteratorInternal streamIterator = this.CreateStreamIterator(
isContinuationExpected,
out clientOperation);
return new FeedIteratorInlineCore<T>(new FeedIteratorCore<T>(
streamIterator,
this.responseFactory.CreateQueryFeedUserTypeResponse<T>),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Microsoft.Azure.Cosmos.Linq
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
Expand Down Expand Up @@ -60,6 +61,7 @@ public IQueryable<TElement> CreateQuery<TElement>(Expression expression)

public IQueryable CreateQuery(Expression expression)
{
// ISSUE-TODO-adityasa-2024/1/26 - Investigate if reflection usage can be removed.
Type expressionType = TypeSystem.GetElementType(expression.Type);
Type documentQueryType = typeof(CosmosLinqQuery<bool>).GetGenericTypeDefinition().MakeGenericType(expressionType);
return (IQueryable)Activator.CreateInstance(
Expand All @@ -76,6 +78,7 @@ public IQueryable CreateQuery(Expression expression)

public TResult Execute<TResult>(Expression expression)
{
// ISSUE-TODO-adityasa-2024/1/26 - We should be able to delegate the implementation to ExecuteAggregateAsync method below by providing an Async implementation of ExecuteScalar.
Type cosmosQueryType = typeof(CosmosLinqQuery<bool>).GetGenericTypeDefinition().MakeGenericType(typeof(TResult));
CosmosLinqQuery<TResult> cosmosLINQQuery = (CosmosLinqQuery<TResult>)Activator.CreateInstance(
cosmosQueryType,
Expand All @@ -88,7 +91,7 @@ public TResult Execute<TResult>(Expression expression)
this.allowSynchronousQueryExecution,
this.linqSerializerOptions);
this.onExecuteScalarQueryCallback?.Invoke(cosmosLINQQuery);
return cosmosLINQQuery.ToList().FirstOrDefault();
return cosmosLINQQuery.ExecuteScalar();
}

//Sync execution of query via direct invoke on IQueryProvider.
Expand Down
6 changes: 3 additions & 3 deletions Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,10 @@ IEnumerator IEnumerable.GetEnumerator()

public override string ToString()
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression);
if (querySpec != null)
LinqQuery querySpec = DocumentQueryEvaluator.Evaluate(this.Expression);
if (querySpec.SqlQuerySpec != null)
{
return JsonConvert.SerializeObject(querySpec);
return JsonConvert.SerializeObject(querySpec.SqlQuerySpec);
}

return new Uri(this.client.ServiceEndpoint, this.documentsFeedOrDatabaseLink).ToString();
Expand Down
17 changes: 9 additions & 8 deletions Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal static class DocumentQueryEvaluator
{
private const string SQLMethod = "AsSQL";

public static SqlQuerySpec Evaluate(
public static LinqQuery Evaluate(
Expression expression,
CosmosLinqSerializerOptionsInternal linqSerializerOptions = null,
IDictionary<object, string> parameters = null)
Expand Down Expand Up @@ -51,7 +51,7 @@ public static bool IsTransformExpression(Expression expression)
/// foreach(Database db in client.CreateDatabaseQuery()) {}
/// </summary>
/// <param name="expression"></param>
private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression)
private static LinqQuery HandleEmptyQuery(ConstantExpression expression)
{
if (expression.Value == null)
{
Expand All @@ -69,11 +69,12 @@ private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression)
ClientResources.BadQuery_InvalidExpression,
expression.ToString()));
}

//No query specified.
return null;
return new LinqQuery(sqlQuerySpec: null, clientOperation: ClientOperation.None);
}

private static SqlQuerySpec HandleMethodCallExpression(
private static LinqQuery HandleMethodCallExpression(
MethodCallExpression expression,
IDictionary<object, string> parameters,
CosmosLinqSerializerOptionsInternal linqSerializerOptions = null)
Expand All @@ -100,7 +101,7 @@ private static SqlQuerySpec HandleMethodCallExpression(
/// foreach(string record in client.CreateDocumentQuery().Navigate("Raw JQuery"))
/// </summary>
/// <param name="expression"></param>
private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression expression)
private static LinqQuery HandleAsSqlTransformExpression(MethodCallExpression expression)
{
Expression paramExpression = expression.Arguments[1];

Expand All @@ -122,7 +123,7 @@ private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression
}
}

private static SqlQuerySpec GetSqlQuerySpec(object value)
private static LinqQuery GetSqlQuerySpec(object value)
{
if (value == null)
{
Expand All @@ -133,11 +134,11 @@ private static SqlQuerySpec GetSqlQuerySpec(object value)
}
else if (value.GetType() == typeof(SqlQuerySpec))
{
return (SqlQuerySpec)value;
return new LinqQuery((SqlQuerySpec)value, ClientOperation.None);
}
else if (value.GetType() == typeof(string))
{
return new SqlQuerySpec((string)value);
return new LinqQuery(new SqlQuerySpec((string)value), ClientOperation.None);
}
else
{
Expand Down
Loading
Loading