-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transform aggregate functions over subqueries on SQL Server
Closes #34256
- Loading branch information
Showing
10 changed files
with
561 additions
and
370 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
173 changes: 173 additions & 0 deletions
173
src/EFCore.SqlServer/Query/Internal/SqlServerAggregateOverSubqueryPostprocessor.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
|
||
using System.Globalization; | ||
using Microsoft.EntityFrameworkCore.Query.SqlExpressions; | ||
|
||
namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal; | ||
|
||
/// <summary> | ||
/// SQL Server doesn't support aggregate function invocations over subqueries, or other aggregate function invocations; this | ||
/// postprocessor lifts such subqueries out to an OUTER APPLY/JOIN on the SELECT to work around this limitation. | ||
/// </summary> | ||
/// <remarks> | ||
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to | ||
/// the same compatibility standards as public APIs. It may be changed or removed without notice in | ||
/// any release. You should only use it directly in your code with extreme caution and knowing that | ||
/// doing so can result in application failures when updating to a new Entity Framework Core release. | ||
/// </remarks> | ||
public class SqlServerAggregateOverSubqueryPostprocessor(SqlAliasManager sqlAliasManager) : ExpressionVisitor | ||
{ | ||
private SelectExpression? _currentSelect; | ||
private bool _inAggregateInvocation; | ||
private List<JoinExpressionBase>? _joinsToAdd; | ||
private bool _isCorrelatedSubquery; | ||
private HashSet<string>? _tableAliasesInScope; | ||
|
||
/// <summary> | ||
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to | ||
/// the same compatibility standards as public APIs. It may be changed or removed without notice in | ||
/// any release. You should only use it directly in your code with extreme caution and knowing that | ||
/// doing so can result in application failures when updating to a new Entity Framework Core release. | ||
/// </summary> | ||
protected override Expression VisitExtension(Expression node) | ||
{ | ||
switch (node) | ||
{ | ||
case SelectExpression select: | ||
{ | ||
var (parentSelect, parentJoinsToAdd, parentAggregateInvocation) = (_currentSelect, _joinsToAdd, _inAggregateInvocation); | ||
(_currentSelect, _joinsToAdd, _inAggregateInvocation) = (select, null, false); | ||
|
||
// If _tableAliasesInScope is non-null, we're tracking which table aliases are in scope for the current subquery, to detect | ||
// correlated vs. uncorrelated subqueries. Add and remove the select's tables to _tableAliasInScope. | ||
SelectExpression visitedSelect; | ||
if (_tableAliasesInScope is null) | ||
{ | ||
visitedSelect = (SelectExpression)base.VisitExtension(node); | ||
} | ||
else | ||
{ | ||
List<string> tableAliases = select.Tables.Select(t => t.UnwrapJoin().Alias).Where(a => a is not null).ToList()!; | ||
_tableAliasesInScope.UnionWith(tableAliases); | ||
visitedSelect = (SelectExpression)base.VisitExtension(node); | ||
_tableAliasesInScope.ExceptWith(tableAliases); | ||
} | ||
|
||
// A subquery is being lifted out somewhere inside this SelectExpression; add the join. | ||
if (_joinsToAdd is not null) | ||
{ | ||
visitedSelect = visitedSelect.Update( | ||
[.. visitedSelect.Tables, .. _joinsToAdd], | ||
visitedSelect.Predicate, | ||
visitedSelect.GroupBy, | ||
visitedSelect.Having, | ||
visitedSelect.Projection, | ||
visitedSelect.Orderings, | ||
visitedSelect.Offset, | ||
visitedSelect.Limit); | ||
} | ||
|
||
(_currentSelect, _joinsToAdd, _inAggregateInvocation) = (parentSelect, parentJoinsToAdd, parentAggregateInvocation); | ||
return visitedSelect; | ||
} | ||
|
||
// TODO: We currently don't represent the fact that a function is an aggregate or not; so for now we just match a few well-known | ||
// functions. Improve this in the future. | ||
case SqlFunctionExpression function | ||
when function.Name.ToLower(CultureInfo.InvariantCulture) is "sum" or "avg" or "min" or "max" or "count": | ||
{ | ||
var parentInAggregateInvocation = _inAggregateInvocation; | ||
_inAggregateInvocation = true; | ||
var result = base.VisitExtension(function); | ||
_inAggregateInvocation = parentInAggregateInvocation; | ||
return result; | ||
} | ||
|
||
// We have a scalar subquery inside an aggregate function argument; lift it out to an OUTER APPLY/CROSS JOIN that will be added | ||
// to the containing SELECT, and return a ColumnExpression in its place that references that OUTER APPLY/CROSS JOIN. | ||
|
||
// Note that there's an assumption here that the query being lifted out will only ever return a single row (and column); | ||
// if it didn't, the APPLY/JOIN would cause the principal row to get duplicated, producing incorrect results. | ||
// It shouldn't be possible to produce such a state of affairs with LINQ, and since this is a scalar subquery, that SQL | ||
// would fail in any case even if it weren't wrapped inside an aggregate function invocation. | ||
case ScalarSubqueryExpression scalarSubquery when _inAggregateInvocation && _currentSelect is not null: | ||
return LiftSubqueryToJoin(scalarSubquery.Subquery); | ||
|
||
// EXISTS is slightly more complicated; unlike a scalar subquery, where we can just lift out the wrapped subquery (it already | ||
// returns a scalar), with EXISTS we need to conserve the ExistsExpression, pushing it down into a subquery which will become | ||
// the OUTER APPLY (which needs to return a single boolean value). | ||
#pragma warning disable EF1001 // SelectExpression constructor is internal | ||
case ExistsExpression exists when _inAggregateInvocation && _currentSelect is not null: | ||
{ | ||
var wrapperSubquery = new SelectExpression(exists, sqlAliasManager); | ||
wrapperSubquery.ApplyProjection(); | ||
return LiftSubqueryToJoin(wrapperSubquery); | ||
} | ||
|
||
case InExpression { Subquery: SelectExpression } inExpression when _inAggregateInvocation && _currentSelect is not null: | ||
{ | ||
var wrapperSubquery = new SelectExpression(inExpression, sqlAliasManager); | ||
wrapperSubquery.ApplyProjection(); | ||
return LiftSubqueryToJoin(wrapperSubquery); | ||
} | ||
#pragma warning restore EF1001 | ||
|
||
// If _tableAliasesInScope is non-null, we're tracking which table aliases are in scope for the current subquery, to detect | ||
// correlated vs. uncorrelated subqueries. If we have a column referencing a table that isn't in the current scope, that means | ||
// we're in a correlated subquery. | ||
case ColumnExpression column when _tableAliasesInScope?.Contains(column.TableAlias) == false: | ||
_isCorrelatedSubquery = true; | ||
return base.VisitExtension(column); | ||
|
||
case ShapedQueryExpression shapedQueryExpression: | ||
shapedQueryExpression = shapedQueryExpression | ||
.UpdateQueryExpression(Visit(shapedQueryExpression.QueryExpression)) | ||
.UpdateShaperExpression(Visit(shapedQueryExpression.ShaperExpression)); | ||
return shapedQueryExpression.UpdateShaperExpression(Visit(shapedQueryExpression.ShaperExpression)); | ||
|
||
default: | ||
return base.VisitExtension(node); | ||
} | ||
|
||
ColumnExpression LiftSubqueryToJoin(SelectExpression subquery) | ||
{ | ||
var (parentIsCorrelatedSubquery, parentTableAliasesInScope) = (_isCorrelatedSubquery, _tableAliasesInScope); | ||
(_isCorrelatedSubquery, _tableAliasesInScope) = (false, new()); | ||
|
||
if (Visit(subquery) is not SelectExpression { Projection: [var projection] } visitedSubquery) | ||
{ | ||
throw new UnreachableException("Invalid subquery"); | ||
} | ||
|
||
// Since the subquery is currently a scalar subquery (or EXISTS), its doesn't have an alias for the subquery, and may not have | ||
// an alias on its projection either. As part of lifting it out, we need to assign both aliases, so that the projection can be | ||
// referenced. | ||
var subqueryAlias = sqlAliasManager.GenerateTableAlias("subquery"); | ||
if (projection.Alias is null or "") | ||
{ | ||
projection = new ProjectionExpression(projection.Expression, "value"); | ||
} | ||
|
||
visitedSubquery = visitedSubquery | ||
.Update( | ||
visitedSubquery.Tables, | ||
visitedSubquery.Predicate, | ||
visitedSubquery.GroupBy, | ||
visitedSubquery.Having, | ||
[projection], | ||
visitedSubquery.Orderings, | ||
visitedSubquery.Offset, | ||
visitedSubquery.Limit) | ||
.WithAlias(subqueryAlias); | ||
|
||
_joinsToAdd ??= new(); | ||
_joinsToAdd.Add(_isCorrelatedSubquery ? new OuterApplyExpression(visitedSubquery) : new CrossJoinExpression(visitedSubquery)); | ||
|
||
(_isCorrelatedSubquery, _tableAliasesInScope) = (parentIsCorrelatedSubquery, parentTableAliasesInScope); | ||
|
||
return new ColumnExpression( | ||
projection.Alias, subqueryAlias, projection.Expression.Type, projection.Expression.TypeMapping, nullable: true); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.