Skip to content

Commit

Permalink
Fix Contains within SQL Server aggregate functions
Browse files Browse the repository at this point in the history
Fixes #32374
  • Loading branch information
roji committed Dec 1, 2023
1 parent 17fa62f commit 0eefa4b
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,14 @@ private StructuralTypeReferenceExpression BindComplexProperty(
}
}

private bool TryTranslateAggregateMethodCall(
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[EntityFrameworkInternal]
protected virtual bool TryTranslateAggregateMethodCall(
MethodCallExpression methodCallExpression,
[NotNullWhen(true)] out SqlExpression? translation)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,16 @@ public override bool IsBuffering
=> base.IsBuffering
|| (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery
&& !_multipleActiveResultSetsEnabled);

/// <summary>
/// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not
/// allow subqueries and aggregates in that context.
/// </summary>
/// <remarks>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
public bool InAggregateFunction { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly IRelationalTypeMappingSource _typeMappingSource;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly int _sqlServerCompatibilityLevel;
Expand All @@ -34,7 +34,7 @@ public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQu
public SqlServerQueryableMethodTranslatingExpressionVisitor(
QueryableMethodTranslatingExpressionVisitorDependencies dependencies,
RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies,
QueryCompilationContext queryCompilationContext,
SqlServerQueryCompilationContext queryCompilationContext,
ISqlServerSingletonOptions sqlServerSingletonOptions)
: base(dependencies, relationalDependencies, queryCompilationContext)
{
Expand Down Expand Up @@ -121,6 +121,103 @@ protected override Expression VisitExtension(Expression extensionExpression)
return base.VisitExtension(extensionExpression);
}

#region Aggregate functions

// We override these for SQL Server to add tracking whether we're inside an aggregate function context, since SQL Server doesn't
// support subqueries (or aggregates) within them.

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateAverage(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateSum(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateCount(source, predicate);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateLongCount(source, predicate);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateMax(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;
var result = base.TranslateMin(source, selector, resultType);
_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;
return result;
}

#endregion Aggregate functions

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -315,6 +412,47 @@ static IEnumerable<INavigation> GetAllNavigationsInHierarchy(IEntityType entityT
.SelectMany(t => t.GetDeclaredNavigations());
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item)
{
var translatedSource = base.TranslateContains(source, item);

// SQL Server does not support subqueries inside aggregate functions (e.g. COUNT(SELECT * FROM OPENJSON(@p)...)).
// As a result, we track whether we're within an aggregate function; if we are, and we see the regular Contains translation
// (which uses IN with an OPENJSON subquery - incompatible), we transform it to the old-style IN+constants translation (as if a
// low SQL Server compatibility level were defined)
if (_queryCompilationContext.InAggregateFunction
&& translatedSource is not null
&& TryGetProjection(translatedSource, out var projection)
&& projection is InExpression
{
Item: var translatedItem,
Subquery:
{
Tables: [SqlServerOpenJsonExpression { Arguments: [SqlParameterExpression parameter] } openJsonExpression],
GroupBy: [],
Having: null,
IsDistinct: false,
Limit: null,
Offset: null,
Orderings: [],
Projection: [{ Expression: ColumnExpression { Name: "value", Table: var projectionColumnTable } }]
}
}
&& projectionColumnTable == openJsonExpression)
{
var newInExpression = _sqlExpressionFactory.In(translatedItem, parameter);
return source.UpdateQueryExpression(_sqlExpressionFactory.Select(newInExpression));
}

return translatedSource;
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -504,6 +642,29 @@ protected override bool IsValidSelectExpressionForExecuteUpdate(
return false;
}

private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection)
{
var shaperExpression = shapedQueryExpression.ShaperExpression;
// No need to check ConvertChecked since this is convert node which we may have added during projection
if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression
&& unaryExpression.Operand.Type.IsNullableType()
&& unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type)
{
shaperExpression = unaryExpression.Operand;
}

if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression
&& shaperExpression is ProjectionBindingExpression projectionBindingExpression
&& selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression)
{
projection = sqlExpression;
return true;
}

projection = null;
return false;
}

private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor
{
private readonly Func<TableExpression, TableExpressionBase> _annotationApplyingFunc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ public SqlServerQueryableMethodTranslatingExpressionVisitorFactory(
/// </summary>
public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
=> new SqlServerQueryableMethodTranslatingExpressionVisitor(
Dependencies, RelationalDependencies, queryCompilationContext, _sqlServerSingletonOptions);
Dependencies, RelationalDependencies, (SqlServerQueryCompilationContext)queryCompilationContext, _sqlServerSingletonOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor
{
private readonly QueryCompilationContext _queryCompilationContext;
private readonly SqlServerQueryCompilationContext _queryCompilationContext;
private readonly ISqlExpressionFactory _sqlExpressionFactory;

private static readonly HashSet<string> DateTimeDataTypes
Expand Down Expand Up @@ -73,7 +73,7 @@ private static readonly MethodInfo StringContainsMethodInfo
/// </summary>
public SqlServerSqlTranslatingExpressionVisitor(
RelationalSqlTranslatingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext,
SqlServerQueryCompilationContext queryCompilationContext,
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
: base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor)
{
Expand Down Expand Up @@ -432,6 +432,26 @@ private static string EscapeLikePattern(string pattern)
return builder.ToString();
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool TryTranslateAggregateMethodCall(
MethodCallExpression methodCallExpression,
[NotNullWhen(true)] out SqlExpression? translation)
{
var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction;
_queryCompilationContext.InAggregateFunction = true;

var result = base.TryTranslateAggregateMethodCall(methodCallExpression, out translation);

Check failure on line 448 in src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

View check run for this annotation

Azure Pipelines / efcore-ci (Build Linux)

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs#L448

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs(448,27): error EF1001: (NETCORE_ENGINEERING_TELEMETRY=Build) Microsoft.EntityFrameworkCore.Query.RelationalSqlTranslatingExpressionVisitor.TryTranslateAggregateMethodCall is an internal API that supports the Entity Framework Core infrastructure and not subject to the same compatibility standards as public APIs. It may be changed or removed without notice in any release.

Check failure on line 448 in src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

View check run for this annotation

Azure Pipelines / efcore-ci (Build macOS)

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs#L448

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs(448,27): error EF1001: (NETCORE_ENGINEERING_TELEMETRY=Build) Microsoft.EntityFrameworkCore.Query.RelationalSqlTranslatingExpressionVisitor.TryTranslateAggregateMethodCall is an internal API that supports the Entity Framework Core infrastructure and not subject to the same compatibility standards as public APIs. It may be changed or removed without notice in any release.

Check failure on line 448 in src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

View check run for this annotation

Azure Pipelines / efcore-ci

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs#L448

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs(448,27): error EF1001: (NETCORE_ENGINEERING_TELEMETRY=Build) Microsoft.EntityFrameworkCore.Query.RelationalSqlTranslatingExpressionVisitor.TryTranslateAggregateMethodCall is an internal API that supports the Entity Framework Core infrastructure and not subject to the same compatibility standards as public APIs. It may be changed or removed without notice in any release.

Check failure on line 448 in src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs

View check run for this annotation

Azure Pipelines / efcore-ci

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs#L448

src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs(448,27): error EF1001: (NETCORE_ENGINEERING_TELEMETRY=Build) Microsoft.EntityFrameworkCore.Query.RelationalSqlTranslatingExpressionVisitor.TryTranslateAggregateMethodCall is an internal API that supports the Entity Framework Core infrastructure and not subject to the same compatibility standards as public APIs. It may be changed or removed without notice in any release.

_queryCompilationContext.InAggregateFunction = previousInAggregateFunction;

return result;
}

private Expression TranslateByteArrayElementAccess(Expression array, Expression index, Type resultType)
{
var visitedArray = Visit(array);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create(
QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor)
=> new SqlServerSqlTranslatingExpressionVisitor(
Dependencies,
queryCompilationContext,
(SqlServerQueryCompilationContext)queryCompilationContext,
queryableMethodTranslatingExpressionVisitor);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1923,4 +1923,89 @@ public virtual Task Not_Any_false(bool async)
=> AssertQuery(
async,
ss => ss.Set<Customer>().Where(c => !c.Orders.Any(o => false)).Select(c => c.CustomerID));

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_aggregate_function_with_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertQuery(
async,
ss => ss.Set<Customer>()
.GroupBy(c => c.Country)
.Select(g => g.Count(c => cities.Contains(c.City))));
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Average_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertAverage(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Sum_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertSum(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Count_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertCount(
async,
ss => ss.Set<Customer>(),
predicate: c => cities.Contains(c.City));
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_LongCount_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertLongCount(
async,
ss => ss.Set<Customer>(),
predicate: c => cities.Contains(c.City));
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Max_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertMax(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}

[ConditionalTheory] // #32374
[MemberData(nameof(IsAsyncData))]
public virtual Task Contains_inside_Min_without_GroupBy(bool async)
{
var cities = new[] { "London", "Berlin" };

return AssertMin(
async,
ss => ss.Set<Customer>(),
selector: c => cities.Contains(c.City) ? 1 : 0);
}
}
Loading

0 comments on commit 0eefa4b

Please sign in to comment.