From 9def3dbd4ccdfa3d21a4fb80f3c5fe58d2472966 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 30 Nov 2023 17:24:17 +0100 Subject: [PATCH] Fix Contains within SQL Server aggregate functions Fixes #32374 --- ...lationalSqlTranslatingExpressionVisitor.cs | 9 +- .../SqlServerQueryCompilationContext.cs | 12 ++ ...yableMethodTranslatingExpressionVisitor.cs | 165 +++++++++++++++++- ...thodTranslatingExpressionVisitorFactory.cs | 2 +- ...qlServerSqlTranslatingExpressionVisitor.cs | 26 ++- ...rSqlTranslatingExpressionVisitorFactory.cs | 2 +- ...orthwindAggregateOperatorsQueryTestBase.cs | 85 +++++++++ ...indAggregateOperatorsQuerySqlServerTest.cs | 94 ++++++++++ 8 files changed, 388 insertions(+), 7 deletions(-) diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index d81479b5577..59b3a8c882a 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -1423,7 +1423,14 @@ private StructuralTypeReferenceExpression BindComplexProperty( } } - private bool TryTranslateAggregateMethodCall( + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + [EntityFrameworkInternal] + protected virtual bool TryTranslateAggregateMethodCall( MethodCallExpression methodCallExpression, [NotNullWhen(true)] out SqlExpression? translation) { diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs index 98cf3777fb5..fbda6592e3b 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryCompilationContext.cs @@ -39,4 +39,16 @@ public override bool IsBuffering => base.IsBuffering || (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery && !_multipleActiveResultSetsEnabled); + + /// + /// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not + /// allow subqueries and aggregates in that context. + /// + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public bool InAggregateFunction { get; set; } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs index be2c39fcfe2..8dc1c8a73a7 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitor.cs @@ -18,7 +18,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQueryableMethodTranslatingExpressionVisitor { - private readonly QueryCompilationContext _queryCompilationContext; + private readonly SqlServerQueryCompilationContext _queryCompilationContext; private readonly IRelationalTypeMappingSource _typeMappingSource; private readonly ISqlExpressionFactory _sqlExpressionFactory; private readonly int _sqlServerCompatibilityLevel; @@ -34,7 +34,7 @@ public class SqlServerQueryableMethodTranslatingExpressionVisitor : RelationalQu public SqlServerQueryableMethodTranslatingExpressionVisitor( QueryableMethodTranslatingExpressionVisitorDependencies dependencies, RelationalQueryableMethodTranslatingExpressionVisitorDependencies relationalDependencies, - QueryCompilationContext queryCompilationContext, + SqlServerQueryCompilationContext queryCompilationContext, ISqlServerSingletonOptions sqlServerSingletonOptions) : base(dependencies, relationalDependencies, queryCompilationContext) { @@ -121,6 +121,103 @@ protected override Expression VisitExtension(Expression extensionExpression) return base.VisitExtension(extensionExpression); } + #region Aggregate functions + + // We override these for SQL Server to add tracking whether we're inside an aggregate function context, since SQL Server doesn't + // support subqueries (or aggregates) within them. + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateAverage(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + var result = base.TranslateAverage(source, selector, resultType); + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateSum(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + var result = base.TranslateSum(source, selector, resultType); + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateCount(ShapedQueryExpression source, LambdaExpression? predicate) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + var result = base.TranslateCount(source, predicate); + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateLongCount(ShapedQueryExpression source, LambdaExpression? predicate) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + var result = base.TranslateLongCount(source, predicate); + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + var result = base.TranslateMax(source, selector, resultType); + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + return result; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + var result = base.TranslateMin(source, selector, resultType); + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + return result; + } + + #endregion Aggregate functions + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -315,6 +412,47 @@ static IEnumerable GetAllNavigationsInHierarchy(IEntityType entityT .SelectMany(t => t.GetDeclaredNavigations()); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override ShapedQueryExpression? TranslateContains(ShapedQueryExpression source, Expression item) + { + var translatedSource = base.TranslateContains(source, item); + + // SQL Server does not support subqueries inside aggregate functions (e.g. COUNT(SELECT * FROM OPENJSON(@p)...)). + // As a result, we track whether we're within an aggregate function; if we are, and we see the regular Contains translation + // (which uses IN with an OPENJSON subquery - incompatible), we transform it to the old-style IN+constants translation (as if a + // low SQL Server compatibility level were defined) + if (_queryCompilationContext.InAggregateFunction + && translatedSource is not null + && TryGetProjection(translatedSource, out var projection) + && projection is InExpression + { + Item: var translatedItem, + Subquery: + { + Tables: [SqlServerOpenJsonExpression { Arguments: [SqlParameterExpression parameter] } openJsonExpression], + GroupBy: [], + Having: null, + IsDistinct: false, + Limit: null, + Offset: null, + Orderings: [], + Projection: [{ Expression: ColumnExpression { Name: "value", Table: var projectionColumnTable } }] + } + } + && projectionColumnTable == openJsonExpression) + { + var newInExpression = _sqlExpressionFactory.In(translatedItem, parameter); + return source.UpdateQueryExpression(_sqlExpressionFactory.Select(newInExpression)); + } + + return translatedSource; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -504,6 +642,29 @@ protected override bool IsValidSelectExpressionForExecuteUpdate( return false; } + private bool TryGetProjection(ShapedQueryExpression shapedQueryExpression, [NotNullWhen(true)] out SqlExpression? projection) + { + var shaperExpression = shapedQueryExpression.ShaperExpression; + // No need to check ConvertChecked since this is convert node which we may have added during projection + if (shaperExpression is UnaryExpression { NodeType: ExpressionType.Convert } unaryExpression + && unaryExpression.Operand.Type.IsNullableType() + && unaryExpression.Operand.Type.UnwrapNullableType() == unaryExpression.Type) + { + shaperExpression = unaryExpression.Operand; + } + + if (shapedQueryExpression.QueryExpression is SelectExpression selectExpression + && shaperExpression is ProjectionBindingExpression projectionBindingExpression + && selectExpression.GetProjection(projectionBindingExpression) is SqlExpression sqlExpression) + { + projection = sqlExpression; + return true; + } + + projection = null; + return false; + } + private sealed class TemporalAnnotationApplyingExpressionVisitor : ExpressionVisitor { private readonly Func _annotationApplyingFunc; diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitorFactory.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitorFactory.cs index 37389e4edf2..f8c045a5e72 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitorFactory.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerQueryableMethodTranslatingExpressionVisitorFactory.cs @@ -49,5 +49,5 @@ public SqlServerQueryableMethodTranslatingExpressionVisitorFactory( /// public virtual QueryableMethodTranslatingExpressionVisitor Create(QueryCompilationContext queryCompilationContext) => new SqlServerQueryableMethodTranslatingExpressionVisitor( - Dependencies, RelationalDependencies, queryCompilationContext, _sqlServerSingletonOptions); + Dependencies, RelationalDependencies, (SqlServerQueryCompilationContext)queryCompilationContext, _sqlServerSingletonOptions); } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 30aecaa7b57..44378357d38 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; /// public class SqlServerSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExpressionVisitor { - private readonly QueryCompilationContext _queryCompilationContext; + private readonly SqlServerQueryCompilationContext _queryCompilationContext; private readonly ISqlExpressionFactory _sqlExpressionFactory; private static readonly HashSet DateTimeDataTypes @@ -73,7 +73,7 @@ private static readonly MethodInfo StringContainsMethodInfo /// public SqlServerSqlTranslatingExpressionVisitor( RelationalSqlTranslatingExpressionVisitorDependencies dependencies, - QueryCompilationContext queryCompilationContext, + SqlServerQueryCompilationContext queryCompilationContext, QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) : base(dependencies, queryCompilationContext, queryableMethodTranslatingExpressionVisitor) { @@ -432,6 +432,28 @@ private static string EscapeLikePattern(string pattern) return builder.ToString(); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override bool TryTranslateAggregateMethodCall( + MethodCallExpression methodCallExpression, + [NotNullWhen(true)] out SqlExpression? translation) + { + var previousInAggregateFunction = _queryCompilationContext.InAggregateFunction; + _queryCompilationContext.InAggregateFunction = true; + +#pragma warning disable EF1001 // Internal EF Core API usage. + var result = base.TryTranslateAggregateMethodCall(methodCallExpression, out translation); +#pragma warning restore EF1001 // Internal EF Core API usage. + + _queryCompilationContext.InAggregateFunction = previousInAggregateFunction; + + return result; + } + private Expression TranslateByteArrayElementAccess(Expression array, Expression index, Type resultType) { var visitedArray = Visit(array); diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitorFactory.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitorFactory.cs index f46d424bde3..e038210d2cd 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitorFactory.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitorFactory.cs @@ -39,6 +39,6 @@ public virtual RelationalSqlTranslatingExpressionVisitor Create( QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpressionVisitor) => new SqlServerSqlTranslatingExpressionVisitor( Dependencies, - queryCompilationContext, + (SqlServerQueryCompilationContext)queryCompilationContext, queryableMethodTranslatingExpressionVisitor); } diff --git a/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs index d320ac31868..1340464a075 100644 --- a/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/NorthwindAggregateOperatorsQueryTestBase.cs @@ -1923,4 +1923,89 @@ public virtual Task Not_Any_false(bool async) => AssertQuery( async, ss => ss.Set().Where(c => !c.Orders.Any(o => false)).Select(c => c.CustomerID)); + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_aggregate_function_with_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertQuery( + async, + ss => ss.Set() + .GroupBy(c => c.Country) + .Select(g => g.Count(c => cities.Contains(c.City)))); + } + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_Average_without_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertAverage( + async, + ss => ss.Set(), + selector: c => cities.Contains(c.City) ? 1 : 0); + } + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_Sum_without_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertSum( + async, + ss => ss.Set(), + selector: c => cities.Contains(c.City) ? 1 : 0); + } + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_Count_without_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertCount( + async, + ss => ss.Set(), + predicate: c => cities.Contains(c.City)); + } + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_LongCount_without_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertLongCount( + async, + ss => ss.Set(), + predicate: c => cities.Contains(c.City)); + } + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_Max_without_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertMax( + async, + ss => ss.Set(), + selector: c => cities.Contains(c.City) ? 1 : 0); + } + + [ConditionalTheory] // #32374 + [MemberData(nameof(IsAsyncData))] + public virtual Task Contains_inside_Min_without_GroupBy(bool async) + { + var cities = new[] { "London", "Berlin" }; + + return AssertMin( + async, + ss => ss.Set(), + selector: c => cities.Contains(c.City) ? 1 : 0); + } } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs index 3f1bf17251e..3e0acf57fed 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindAggregateOperatorsQuerySqlServerTest.cs @@ -2902,6 +2902,100 @@ FROM [Customers] AS [c] """); } + public override async Task Contains_inside_aggregate_function_with_GroupBy(bool async) + { + await base.Contains_inside_aggregate_function_with_GroupBy(async); + + AssertSql( + """ +SELECT COUNT(CASE + WHEN [c].[City] IN (N'London', N'Berlin') THEN 1 +END) +FROM [Customers] AS [c] +GROUP BY [c].[Country] +"""); + } + + public override async Task Contains_inside_Average_without_GroupBy(bool async) + { + await base.Contains_inside_Average_without_GroupBy(async); + + AssertSql( + """ +SELECT AVG(CAST(CASE + WHEN [c].[City] IN (N'London', N'Berlin') THEN 1 + ELSE 0 +END AS float)) +FROM [Customers] AS [c] +"""); + } + + public override async Task Contains_inside_Sum_without_GroupBy(bool async) + { + await base.Contains_inside_Sum_without_GroupBy(async); + + AssertSql( + """ +SELECT COALESCE(SUM(CASE + WHEN [c].[City] IN (N'London', N'Berlin') THEN 1 + ELSE 0 +END), 0) +FROM [Customers] AS [c] +"""); + } + + public override async Task Contains_inside_Count_without_GroupBy(bool async) + { + await base.Contains_inside_Count_without_GroupBy(async); + + AssertSql( + """ +SELECT COUNT(*) +FROM [Customers] AS [c] +WHERE [c].[City] IN (N'London', N'Berlin') +"""); + } + + public override async Task Contains_inside_LongCount_without_GroupBy(bool async) + { + await base.Contains_inside_LongCount_without_GroupBy(async); + + AssertSql( + """ +SELECT COUNT_BIG(*) +FROM [Customers] AS [c] +WHERE [c].[City] IN (N'London', N'Berlin') +"""); + } + + public override async Task Contains_inside_Max_without_GroupBy(bool async) + { + await base.Contains_inside_Max_without_GroupBy(async); + + AssertSql( + """ +SELECT MAX(CASE + WHEN [c].[City] IN (N'London', N'Berlin') THEN 1 + ELSE 0 +END) +FROM [Customers] AS [c] +"""); + } + + public override async Task Contains_inside_Min_without_GroupBy(bool async) + { + await base.Contains_inside_Min_without_GroupBy(async); + + AssertSql( + """ +SELECT MIN(CASE + WHEN [c].[City] IN (N'London', N'Berlin') THEN 1 + ELSE 0 +END) +FROM [Customers] AS [c] +"""); + } + private void AssertSql(params string[] expected) => Fixture.TestSqlLoggerFactory.AssertBaseline(expected);