forked from npgsql/efcore.pg
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Infrastructure for custom aggregate translation
Takes care of the built-in aggregates only for now (Min/Max/etc.) Part of npgsql#727
- Loading branch information
Showing
14 changed files
with
490 additions
and
153 deletions.
There are no files selected for viewing
5 changes: 1 addition & 4 deletions
5
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 0 additions & 4 deletions
4
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlMultirangeDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 1 addition & 3 deletions
4
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlRangeDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 1 addition & 3 deletions
4
src/EFCore.PG/Extensions/DbFunctionsExtensions/NpgsqlTrigramsDbFunctionsExtensions.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 17 additions & 0 deletions
17
...re.PG/Query/ExpressionTranslators/Internal/NpgsqlAggregateMethodCallTranslatorProvider.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlAggregateMethodCallTranslatorProvider : RelationalAggregateMethodCallTranslatorProvider | ||
{ | ||
public NpgsqlAggregateMethodCallTranslatorProvider(RelationalAggregateMethodCallTranslatorProviderDependencies dependencies) | ||
: base(dependencies) | ||
{ | ||
var sqlExpressionFactory = (NpgsqlSqlExpressionFactory)dependencies.SqlExpressionFactory; | ||
var typeMappingSource = dependencies.RelationalTypeMappingSource; | ||
|
||
AddTranslators( | ||
new IAggregateMethodCallTranslator[] | ||
{ | ||
new NpgsqlQueryableAggregateMethodTranslator(sqlExpressionFactory, typeMappingSource) | ||
}); | ||
} | ||
} |
169 changes: 169 additions & 0 deletions
169
...FCore.PG/Query/ExpressionTranslators/Internal/NpgsqlQueryableAggregateMethodTranslator.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
using static Npgsql.EntityFrameworkCore.PostgreSQL.Utilities.Statics; | ||
|
||
namespace Npgsql.EntityFrameworkCore.PostgreSQL.Query.ExpressionTranslators.Internal; | ||
|
||
public class NpgsqlQueryableAggregateMethodTranslator : IAggregateMethodCallTranslator | ||
{ | ||
private readonly NpgsqlSqlExpressionFactory _sqlExpressionFactory; | ||
private readonly IRelationalTypeMappingSource _typeMappingSource; | ||
|
||
public NpgsqlQueryableAggregateMethodTranslator( | ||
NpgsqlSqlExpressionFactory sqlExpressionFactory, | ||
IRelationalTypeMappingSource typeMappingSource) | ||
{ | ||
_sqlExpressionFactory = sqlExpressionFactory; | ||
_typeMappingSource = typeMappingSource; | ||
} | ||
|
||
public virtual SqlExpression? Translate( | ||
MethodInfo method, | ||
EnumerableExpression source, | ||
IReadOnlyList<SqlExpression> arguments, | ||
IDiagnosticsLogger<DbLoggerCategory.Query> logger) | ||
{ | ||
if (method.DeclaringType == typeof(Queryable)) | ||
{ | ||
var methodInfo = method.IsGenericMethod | ||
? method.GetGenericMethodDefinition() | ||
: method; | ||
switch (methodInfo.Name) | ||
{ | ||
case nameof(Queryable.Average) | ||
when (QueryableMethods.IsAverageWithoutSelector(methodInfo) | ||
|| QueryableMethods.IsAverageWithSelector(methodInfo)) | ||
&& source.Selector is SqlExpression averageSqlExpression: | ||
var averageInputType = averageSqlExpression.Type; | ||
if (averageInputType == typeof(int) | ||
|| averageInputType == typeof(long)) | ||
{ | ||
averageSqlExpression = _sqlExpressionFactory.ApplyDefaultTypeMapping( | ||
_sqlExpressionFactory.Convert(averageSqlExpression, typeof(double))); | ||
} | ||
|
||
return averageInputType == typeof(float) | ||
? _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"AVG", | ||
new[] { averageSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(double)), | ||
averageSqlExpression.Type, | ||
averageSqlExpression.TypeMapping) | ||
: _sqlExpressionFactory.AggregateFunction( | ||
"AVG", | ||
new[] { averageSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
averageSqlExpression.Type, | ||
averageSqlExpression.TypeMapping); | ||
|
||
// PostgreSQL COUNT() always returns bigint, so we need to downcast to int | ||
case nameof(Queryable.Count) | ||
when methodInfo == QueryableMethods.CountWithoutPredicate | ||
|| methodInfo == QueryableMethods.CountWithPredicate: | ||
var countSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); | ||
return _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.ApplyDefaultTypeMapping( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"COUNT", | ||
new[] { countSqlExpression }, | ||
nullable: false, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(long))), | ||
typeof(int), _typeMappingSource.FindMapping(typeof(int))); | ||
|
||
case nameof(Queryable.LongCount) | ||
when methodInfo == QueryableMethods.LongCountWithoutPredicate | ||
|| methodInfo == QueryableMethods.LongCountWithPredicate: | ||
var longCountSqlExpression = (source.Selector as SqlExpression) ?? _sqlExpressionFactory.Fragment("*"); | ||
return _sqlExpressionFactory.ApplyDefaultTypeMapping( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"COUNT", | ||
new[] { longCountSqlExpression }, | ||
nullable: false, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(long))); | ||
|
||
case nameof(Queryable.Max) | ||
when (methodInfo == QueryableMethods.MaxWithoutSelector | ||
|| methodInfo == QueryableMethods.MaxWithSelector) | ||
&& source.Selector is SqlExpression maxSqlExpression: | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"MAX", | ||
new[] { maxSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
maxSqlExpression.Type, | ||
maxSqlExpression.TypeMapping); | ||
|
||
case nameof(Queryable.Min) | ||
when (methodInfo == QueryableMethods.MinWithoutSelector | ||
|| methodInfo == QueryableMethods.MinWithSelector) | ||
&& source.Selector is SqlExpression minSqlExpression: | ||
return _sqlExpressionFactory.AggregateFunction( | ||
"MIN", | ||
new[] { minSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
minSqlExpression.Type, | ||
minSqlExpression.TypeMapping); | ||
|
||
// In PostgreSQL SUM() doesn't return the same type as its argument for smallint, int and bigint. | ||
// Cast to get the same type. | ||
// http://www.postgresql.org/docs/current/static/functions-aggregate.html | ||
case nameof(Queryable.Sum) | ||
when (QueryableMethods.IsSumWithoutSelector(methodInfo) | ||
|| QueryableMethods.IsSumWithSelector(methodInfo)) | ||
&& source.Selector is SqlExpression sumSqlExpression: | ||
var sumInputType = sumSqlExpression.Type; | ||
|
||
// Note that there is no Sum over short in LINQ | ||
if (sumInputType == typeof(int)) | ||
{ | ||
return _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"SUM", | ||
new[] { sumSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(long)), | ||
sumInputType, | ||
sumSqlExpression.TypeMapping); | ||
} | ||
|
||
if (sumInputType == typeof(long)) | ||
{ | ||
return _sqlExpressionFactory.Convert( | ||
_sqlExpressionFactory.AggregateFunction( | ||
"SUM", | ||
new[] { sumSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
typeof(decimal)), | ||
sumInputType, | ||
sumSqlExpression.TypeMapping); | ||
} | ||
|
||
return _sqlExpressionFactory.AggregateFunction( | ||
"SUM", | ||
new[] { sumSqlExpression }, | ||
nullable: true, | ||
argumentsPropagateNullability: FalseArrays[1], | ||
source, | ||
sumInputType, | ||
sumSqlExpression.TypeMapping); | ||
} | ||
} | ||
|
||
return null; | ||
} | ||
} |
Oops, something went wrong.