Skip to content

Commit

Permalink
Fix Contains within SQL Server aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Dec 1, 2023
1 parent 17fa62f commit 9def3db
Show file tree
Hide file tree
Showing 8 changed files with 388 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,14 @@ private StructuralTypeReferenceExpression BindComplexProperty(
}
}

private bool TryTranslateAggregateMethodCall(
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[EntityFrameworkInternal]
protected virtual bool TryTranslateAggregateMethodCall(
MethodCallExpression methodCallExpression,
[NotNullWhen(true)] out SqlExpression? translation)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,16 @@ public override bool IsBuffering
=> base.IsBuffering
|| (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery
&& !_multipleActiveResultSetsEnabled);

/// <summary>
/// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not
/// allow subqueries and aggregates in that context.
/// </summary>
/// <remarks>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
public bool InAggregateFunction { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly int _sqlServerCompatibilityLevel;
Expand All @@ -34,7 +34,7 @@ public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQu
public SqlServerQueryableMethodTranslatingExpressionVisitor(
QueryableMethodTranslatingExpressionVisitorDependencies dependencies,
RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies,
QueryCompilationContext queryCompilationContext,
SqlServerQueryCompilationContext queryCompilationContext,
ISqlServerSingletonOptions sqlServerSingletonOptions)
: base(dependencies, relationalDependencies, queryCompilationContext)
{
Expand Down Expand Up @@ -121,6 +121,103 @@ protected override Expression VisitExtension(Expression extensionExpression)
return base.VisitExtension(extensionExpression);
}

#region Aggregate functions

// We override these for SQL Server to add tracking whether we're inside an aggregate function context, since SQL Server doesn't
// support subqueries (or aggregates) within them.

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateAverage(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateSum(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateCount(source, predicate);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateLongCount(source, predicate);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateMax(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateMin(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

#endregion Aggregate functions

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -315,6 +412,47 @@ static IEnumerable<INavigation> GetAllNavigationsInHierarchy(IEntityType entityT
.SelectMany(t => t.GetDeclaredNavigations());
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
var translatedSource = base.TranslateContains(source, item);

// SQL Server does not support subqueries inside aggregate functions (e.g. COUNT(SELECT * FROM OPENJSON(@p)...)).
// As a result, we track whether we're within an aggregate function; if we are, and we see the regular Contains translation
// (which uses IN with an OPENJSON subquery - incompatible), we transform it to the old-style IN+constants translation (as if a
// low SQL Server compatibility level were defined)
if (_queryCompilationContext.InAggregateFunction
&& translatedSource is not null
&& TryGetProjection(translatedSource, out var projection)
&& projection is InExpression
{
Item: var translatedItem,
Subquery:
{
Tables: [SqlServerOpenJsonExpression { Arguments: [SqlParameterExpression parameter] } openJsonExpression],
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
Orderings: [],
Projection: [{ Expression: ColumnExpression { Name: "value", Table: var projectionColumnTable } }]
}
}
&& projectionColumnTable == openJsonExpression)
{
var newInExpression = _sqlExpressionFactory.In(translatedItem, parameter);
return source.UpdateQueryExpression(_sqlExpressionFactory.Select(newInExpression));
}

return translatedSource;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -504,6 +642,29 @@ protected override bool IsValidSelectExpressionForExecuteUpdate(
return false;
}

private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
{
var shaperExpression = shapedQueryExpression.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
{
projection = sqlExpression;
return true;
}

projection = null;
return false;
}

private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
{
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ public SqlServerQueryableMethodTranslatingExpressionVisitorFactory(
/// </summary>
public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
=> new SqlServerQueryableMethodTranslatingExpressionVisitor(
Dependencies, RelationalDependencies, queryCompilationContext, _sqlServerSingletonOptions);
Dependencies, RelationalDependencies, (SqlServerQueryCompilationContext)queryCompilationContext, _sqlServerSingletonOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

private static readonly HashSet<string> DateTimeDataTypes
Expand Down Expand Up @@ -73,7 +73,7 @@ private static readonly MethodInfo StringContainsMethodInfo
/// </summary>
public SqlServerSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext,
SqlServerQueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
{
Expand Down Expand Up @@ -432,6 +432,28 @@ private static string EscapeLikePattern(string pattern)
return builder.ToString();
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool TryTranslateAggregateMethodCall(
MethodCallExpression methodCallExpression,
[NotNullWhen(true)] out SqlExpression? translation)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;

#pragma warning disable EF1001 // Internal EF Core API usage.
var result = base.TryTranslateAggregateMethodCall(methodCallExpression, out translation);
#pragma warning restore EF1001 // Internal EF Core API usage.

_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;

return result;
}

private Expression TranslateByteArrayElementAccess(Expression array, Expression index, Type resultType)
{
var visitedArray = Visit(array);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
=> new SqlServerSqlTranslatingExpressionVisitor(
Dependencies,
queryCompilationContext,
(SqlServerQueryCompilationContext)queryCompilationContext,
queryableMethodTranslatingExpressionVisitor);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1923,4 +1923,89 @@ public virtual Task Not_Any_false(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => !c.Orders.Any(o => false)).Select(c => c.CustomerID));

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_aggregate_function_with_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertQuery(
async,
ss => ss.Set<Customer>()
.GroupBy(c => c.Country)
.Select(g => g.Count(c => cities.Contains(c.City))));
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Average_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertAverage(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Sum_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertSum(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Count_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertCount(
async,
ss => ss.Set<Customer>(),
predicate: c => cities.Contains(c.City));
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_LongCount_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertLongCount(
async,
ss => ss.Set<Customer>(),
predicate: c => cities.Contains(c.City));
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Max_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertMax(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Min_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertMin(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}
}
Loading

0 comments on commit 9def3db

Please sign in to comment.