Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL Server: transform aggregate functions over subqueries #34262

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ public sealed class ProjectionExpression : Expression, IRelationalQuotableExpres
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[EntityFrameworkInternal]
public ProjectionExpression(SqlExpression expression, string alias)
{
Expression = expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ public SelectExpression(
IReadOnlyList<OrderingExpression> orderings,
SqlExpression? offset,
SqlExpression? limit,
IReadOnlySet<string> tags,
IReadOnlyDictionary<string, IAnnotation>? annotations)
IReadOnlySet<string>? tags = null,
IReadOnlyDictionary<string, IAnnotation>? annotations = null)
: this(alias, tables.ToList(), predicate, groupBy.ToList(), having, projections.ToList(), distinct, orderings.ToList(),
offset, limit, tags.ToHashSet(), annotations, sqlAliasManager: null, isMutable: false)
offset, limit, tags?.ToHashSet() ?? [], annotations, sqlAliasManager: null, isMutable: false)
{
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Globalization;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;

/// <summary>
/// SQL Server doesn't support aggregate function invocations over subqueries, or other aggregate function invocations; this
/// postprocessor lifts such subqueries out to an OUTER APPLY/JOIN on the SELECT to work around this limitation.
/// </summary>
/// <remarks>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
public class SqlServerAggregateOverSubqueryPostprocessor(SqlAliasManager sqlAliasManager) : ExpressionVisitor
{
private SelectExpression? _currentSelect;
private bool _inAggregateInvocation;
private bool _aggregateArgumentContainsSubquery;
private List<JoinExpressionBase>? _joinsToAdd;
private bool _isCorrelatedSubquery;
private HashSet<string>? _tableAliasesInScope;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override Expression VisitExtension(Expression node)
{
switch (node)
{
case SelectExpression select:
{
var (parentSelect, parentJoinsToAdd, parentAggregateInvocation) = (_currentSelect, _joinsToAdd, _inAggregateInvocation);
(_currentSelect, _joinsToAdd, _inAggregateInvocation) = (select, null, false);

// If _tableAliasesInScope is non-null, we're tracking which table aliases are in scope for the current subquery, to detect
// correlated vs. uncorrelated subqueries. Add and remove the select's tables to _tableAliasInScope.
SelectExpression visitedSelect;
if (_tableAliasesInScope is null)
{
visitedSelect = (SelectExpression)base.VisitExtension(node);
}
else
{
List<string> tableAliases = select.Tables.Select(t => t.UnwrapJoin().Alias).Where(a => a is not null).ToList()!;
_tableAliasesInScope.UnionWith(tableAliases);
visitedSelect = (SelectExpression)base.VisitExtension(node);
_tableAliasesInScope.ExceptWith(tableAliases);
}

// A subquery is being lifted out somewhere inside this SelectExpression; add the join.
if (_joinsToAdd is not null)
{
visitedSelect = visitedSelect.Update(
[.. visitedSelect.Tables, .. _joinsToAdd],
visitedSelect.Predicate,
visitedSelect.GroupBy,
visitedSelect.Having,
visitedSelect.Projection,
visitedSelect.Orderings,
visitedSelect.Offset,
visitedSelect.Limit);
}

(_currentSelect, _joinsToAdd, _inAggregateInvocation) = (parentSelect, parentJoinsToAdd, parentAggregateInvocation);
return visitedSelect;
}

// TODO: We currently don't represent the fact that a function is an aggregate or not; so for now we just match a few well-known
// functions. Improve this in the future.
case SqlFunctionExpression { IsBuiltIn: true } function
when function.Name.ToLower(CultureInfo.InvariantCulture) is "sum" or "avg" or "min" or "max" or "count":
{
var parentInAggregateInvocation = _inAggregateInvocation;
var parentIsCorrelatedSubquery = _isCorrelatedSubquery;
var parentTableAliasesInScope = _tableAliasesInScope;
var parentAggregateArgumentContainsSubquery = _aggregateArgumentContainsSubquery;
_inAggregateInvocation = true;
_isCorrelatedSubquery = false;
_tableAliasesInScope = new();
_aggregateArgumentContainsSubquery = false;

var result = base.VisitExtension(function);

if (_aggregateArgumentContainsSubquery)
{
// During our visitation of the aggregate function invocation, a subquery was encountered - this is our trigger to
// extract out the argument to be an OUTER APPLY/CROSS JOIN.
if (result is not SqlFunctionExpression { Instance: null, Arguments: [var argument] } visitedFunction)
{
throw new UnreachableException();
}

// Since the subquery is currently a scalar subquery (or EXISTS), its doesn't have an alias for the subquery, and may
// not have an alias on its projection either. As part of lifting it out, we need to assign both aliases, so that the
// projection can be referenced.
var subqueryAlias = sqlAliasManager.GenerateTableAlias("subquery");

SelectExpression liftedSubquery;

if (argument is ScalarSubqueryExpression { Subquery: { Projection: [var subqueryProjection] } subquery })
{
// In the regular, simple case (see else below), we simply extract the entire argument of the aggregate method,
// wrap it in a simple subquery, and add that to the containing SelectExpression.
// But if the aggregate argument happens to be a scalar subqueries directly, wrapping it in a subquery isn't needed:
// we can simply use that scalar subquery directly.

// Note that there's an assumption here that the scalar subquery being extracted out will only ever return a single
// row (and column); if it didn't, the APPLY/JOIN would cause the principal row to get duplicated, producing
// incorrect results. It shouldn't be possible to produce such a state of affairs with LINQ, and in any case,
// placing a multiple row/column-returning subquery inside ScalarSubqueryExpression is a bug - that SQL would fail
// in any case even if it weren't wrapped inside an aggregate function invocation.
if (subqueryProjection.Alias is null or "")
{
subqueryProjection = new ProjectionExpression(subqueryProjection.Expression, "value");
}

liftedSubquery = subquery
.Update(
subquery.Tables,
subquery.Predicate,
subquery.GroupBy,
subquery.Having,
[subqueryProjection],
subquery.Orderings,
subquery.Offset,
subquery.Limit)
.WithAlias(subqueryAlias);
}
else
{
#pragma warning disable EF1001 // SelectExpression constructor is internal
liftedSubquery = new SelectExpression(
subqueryAlias,
tables: Array.Empty<TableExpressionBase>(),
predicate: null,
groupBy: Array.Empty<SqlExpression>(),
having: null,
projections: new[] { new ProjectionExpression(argument, "value") },
distinct: false,
orderings: Array.Empty<OrderingExpression>(),
offset: null,
limit: null);
#pragma warning restore EF1001
}

_joinsToAdd ??= new();
_joinsToAdd.Add(
_isCorrelatedSubquery ? new OuterApplyExpression(liftedSubquery) : new CrossJoinExpression(liftedSubquery));

var projection = liftedSubquery.Projection.Single();

return visitedFunction.Update(
instance: null,
arguments:
[
new ColumnExpression(
projection.Alias, subqueryAlias, projection.Expression.Type, projection.Expression.TypeMapping,
nullable: true)
]);
}

_inAggregateInvocation = parentInAggregateInvocation;
_isCorrelatedSubquery = parentIsCorrelatedSubquery;
_tableAliasesInScope = parentTableAliasesInScope;
_aggregateArgumentContainsSubquery = parentAggregateArgumentContainsSubquery;

return result;
}

// We have a scalar subquery inside an aggregate function argument; lift it out to an OUTER APPLY/CROSS JOIN that will be added
// to the containing SELECT, and return a ColumnExpression in its place that references that OUTER APPLY/CROSS JOIN.
case ScalarSubqueryExpression or ExistsExpression or InExpression { Subquery: not null }
when _inAggregateInvocation && _currentSelect is not null:
_aggregateArgumentContainsSubquery = true;
return base.VisitExtension(node);

// If _tableAliasesInScope is non-null, we're tracking which table aliases are in scope for the current subquery, to detect
// correlated vs. uncorrelated subqueries. If we have a column referencing a table that isn't in the current scope, that means
// we're in a correlated subquery.
case ColumnExpression column when _tableAliasesInScope?.Contains(column.TableAlias) == false:
_isCorrelatedSubquery = true;
return base.VisitExtension(column);

case ShapedQueryExpression shapedQueryExpression:
shapedQueryExpression = shapedQueryExpression
.UpdateQueryExpression(Visit(shapedQueryExpression.QueryExpression))
.UpdateShaperExpression(Visit(shapedQueryExpression.ShaperExpression));
return shapedQueryExpression.UpdateShaperExpression(Visit(shapedQueryExpression.ShaperExpression));

default:
return base.VisitExtension(node);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,6 @@ public override bool IsBuffering
|| (QuerySplittingBehavior == EntityFrameworkCore.QuerySplittingBehavior.SplitQuery
&& !_multipleActiveResultSetsEnabled);

/// <summary>
/// Tracks whether translation is currently within the argument of an aggregate method (e.g. MAX, COUNT); SQL Server does not
/// allow subqueries and aggregates in that context.
/// </summary>
/// <remarks>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </remarks>
public virtual bool InAggregateFunction { get; set; }

/// <inheritdoc />
public override bool SupportsPrecompiledQuery => true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
public class SqlServerQueryTranslationPostprocessor : RelationalQueryTranslationPostprocessor
{
private readonly SqlServerJsonPostprocessor _jsonPostprocessor;
private readonly SqlServerAggregateOverSubqueryPostprocessor _aggregatePostprocessor;
private readonly SkipWithoutOrderByInSplitQueryVerifier _skipWithoutOrderByInSplitQueryVerifier = new();
private readonly SqlServerSqlTreePruner _pruner = new();

Expand All @@ -34,6 +35,7 @@ public SqlServerQueryTranslationPostprocessor(
{
_jsonPostprocessor = new SqlServerJsonPostprocessor(
relationalDependencies.TypeMappingSource, relationalDependencies.SqlExpressionFactory, queryCompilationContext.SqlAliasManager);
_aggregatePostprocessor = new SqlServerAggregateOverSubqueryPostprocessor(queryCompilationContext.SqlAliasManager);
}

/// <summary>
Expand All @@ -47,9 +49,10 @@ public override Expression Process(Expression query)
var query1 = base.Process(query);

var query2 = _jsonPostprocessor.Process(query1);
_skipWithoutOrderByInSplitQueryVerifier.Visit(query2);
var query3 = _aggregatePostprocessor.Visit(query2);
_skipWithoutOrderByInSplitQueryVerifier.Visit(query3);

return query2;
return query3;
}

/// <summary>
Expand Down
Loading