Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Introduce EnumerableExpression which is not SQL token #27969

Merged
merged 1 commit into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions src/EFCore.Relational/Query/EnumerableExpression.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.CompilerServices;

namespace Microsoft.EntityFrameworkCore.Query.SqlExpressions;

/// <summary>
/// <para>
/// An expression that represents an enumerable or group translated from chain over a grouping element.
/// </para>
/// <para>
/// This type is typically used by database providers (and other extensions). It is generally
/// not used in application code.
/// </para>
/// </summary>
public class EnumerableExpression : Expression, IPrintableExpression
{
private readonly List<OrderingExpression> _orderings = new();

/// <summary>
/// Creates a new instance of the <see cref="EnumerableExpression" /> class.
/// </summary>
/// <param name="selector">The underlying sql expression being enumerated.</param>
public EnumerableExpression(Expression selector)
{
Selector = selector;
}

/// <summary>
/// The underlying expression being enumerated.
/// </summary>
public virtual Expression Selector { get; private set; }

/// <summary>
/// The value indicating if distinct operator is applied on the enumerable or not.
/// </summary>
public virtual bool IsDistinct { get; private set; }

/// <summary>
/// The value indicating any predicate applied on the enumerable.
/// </summary>
public virtual SqlExpression? Predicate { get; private set; }

/// <summary>
/// The list of orderings to be applied to the enumerable.
/// </summary>
public virtual IReadOnlyList<OrderingExpression> Orderings => _orderings;


/// <summary>
/// Applies new selector to the <see cref="EnumerableExpression" />.
/// </summary>
public virtual void ApplySelector(Expression expression)
{
Selector = expression;
}

/// <summary>
/// Applies DISTINCT operator to the selector of the <see cref="EnumerableExpression" />.
/// </summary>
public virtual void ApplyDistinct()
{
IsDistinct = true;
}

/// <summary>
/// Applies filter predicate to the <see cref="EnumerableExpression" />.
/// </summary>
/// <param name="sqlExpression">An expression to use for filtering.</param>
public virtual void ApplyPredicate(SqlExpression sqlExpression)
{
if (sqlExpression is SqlConstantExpression sqlConstant
&& sqlConstant.Value is bool boolValue
&& boolValue)
{
return;
}

Predicate = Predicate == null
? sqlExpression
: new SqlBinaryExpression(
ExpressionType.AndAlso,
Predicate,
sqlExpression,
typeof(bool),
sqlExpression.TypeMapping);
}

/// <summary>
/// Applies ordering to the <see cref="EnumerableExpression" />. This overwrites any previous ordering specified.
/// </summary>
/// <param name="orderingExpression">An ordering expression to use for ordering.</param>
public virtual void ApplyOrdering(OrderingExpression orderingExpression)
{
_orderings.Clear();
AppendOrdering(orderingExpression);
}

/// <summary>
/// Appends ordering to the existing orderings of the <see cref="EnumerableExpression" />.
/// </summary>
/// <param name="orderingExpression">An ordering expression to use for ordering.</param>
public virtual void AppendOrdering(OrderingExpression orderingExpression)
{
if (!_orderings.Any(o => o.Expression.Equals(orderingExpression.Expression)))
{
_orderings.Add(orderingExpression.Update(orderingExpression.Expression));
}
}

/// <inheritdoc />
protected override Expression VisitChildren(ExpressionVisitor visitor)
=> throw new InvalidOperationException(
CoreStrings.VisitIsNotAllowed($"{nameof(EnumerableExpression)}.{nameof(VisitChildren)}"));

/// <inheritdoc />
public override ExpressionType NodeType => ExpressionType.Extension;

/// <inheritdoc />
public override Type Type => typeof(IEnumerable<>).MakeGenericType(Selector.Type);

/// <inheritdoc />
public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.AppendLine(nameof(EnumerableExpression) + ":");
using (expressionPrinter.Indent())
{
expressionPrinter.Append("Selector: ");
expressionPrinter.Visit(Selector);
expressionPrinter.AppendLine();
if (IsDistinct)
{
expressionPrinter.AppendLine($"IsDistinct: {IsDistinct}");
}

if (Predicate != null)
{
expressionPrinter.Append("Predicate: ");
expressionPrinter.Visit(Predicate);
expressionPrinter.AppendLine();
}

if (Orderings.Count > 0)
{
expressionPrinter.Append("Orderings: ");
expressionPrinter.VisitCollection(Orderings);
expressionPrinter.AppendLine();
}
}
}

/// <inheritdoc />
public override bool Equals(object? obj) => ReferenceEquals(this, obj);

/// <inheritdoc />
public override int GetHashCode() => RuntimeHelpers.GetHashCode(this);
}
35 changes: 10 additions & 25 deletions src/EFCore.Relational/Query/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -506,31 +506,6 @@ protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpres
return sqlBinaryExpression;
}

/// <inheritdoc />
protected override Expression VisitSqlEnumerable(SqlEnumerableExpression sqlEnumerableExpression)
{
if (sqlEnumerableExpression.Orderings.Count != 0)
{
// TODO: Throw error here because we don't know how to print orderings.
// Though providers can override this method and generate orderings if they have a way to print it.
throw new InvalidOperationException();
}

if (sqlEnumerableExpression.IsDistinct)
{
_relationalCommandBuilder.Append("DISTINCT (");
}

Visit(sqlEnumerableExpression.SqlExpression);

if (sqlEnumerableExpression.IsDistinct)
{
_relationalCommandBuilder.Append(")");
}

return sqlEnumerableExpression;
}

/// <inheritdoc />
protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstantExpression)
{
Expand Down Expand Up @@ -634,6 +609,16 @@ protected override Expression VisitCollate(CollateExpression collateExpression)
return collateExpression;
}

/// <inheritdoc />
protected override Expression VisitDistinct(DistinctExpression distinctExpression)
{
_relationalCommandBuilder.Append("DISTINCT (");
Visit(distinctExpression.Operand);
_relationalCommandBuilder.Append(")");

return distinctExpression;
}

/// <inheritdoc />
protected override Expression VisitCase(CaseExpression caseExpression)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
private ShapedQueryExpression? TranslateAggregateWithPredicate(
ShapedQueryExpression source,
LambdaExpression? predicate,
Func<SqlEnumerableExpression, SqlExpression?> aggregateTranslator,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
Type resultType)
{
var selectExpression = (SelectExpression)source.QueryExpression;
Expand All @@ -1480,7 +1480,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape

HandleGroupByForAggregate(selectExpression, eraseProjection: true);

var translation = aggregateTranslator(new SqlEnumerableExpression(_sqlExpressionFactory.Fragment("*"), distinct: false, null));
var translation = aggregateTranslator(_sqlExpressionFactory.Fragment("*"));
if (translation == null)
{
return null;
Expand All @@ -1500,7 +1500,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
private ShapedQueryExpression? TranslateAggregateWithSelector(
ShapedQueryExpression source,
LambdaExpression? selector,
Func<SqlEnumerableExpression, SqlExpression?> aggregateTranslator,
Func<SqlExpression, SqlExpression?> aggregateTranslator,
bool throwWhenEmpty,
Type resultType)
{
Expand Down Expand Up @@ -1541,7 +1541,7 @@ private static Expression MatchShaperNullabilityForSetOperation(Expression shape
}
}

var projection = aggregateTranslator(new SqlEnumerableExpression(translatedSelector, distinct: false, null));
var projection = aggregateTranslator(translatedSelector);
if (projection == null)
{
return null;
Expand Down
Loading