Skip to content

Commit

Permalink
Implement full parentheses management
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Mar 19, 2023
1 parent fd831f7 commit 2703bf4
Show file tree
Hide file tree
Showing 35 changed files with 348 additions and 322 deletions.
152 changes: 138 additions & 14 deletions src/EFCore.PG/Query/Internal/NpgsqlQuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -438,16 +438,16 @@ protected virtual Expression VisitPostgresBinary(PostgresBinaryExpression binary
{
Check.NotNull(binaryExpression, nameof(binaryExpression));

var requiresBrackets = RequiresBrackets(binaryExpression.Left);
var requiresParentheses = RequiresParentheses(binaryExpression, binaryExpression.Left);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append("(");
}

Visit(binaryExpression.Left);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append(")");
}
Expand Down Expand Up @@ -516,16 +516,16 @@ binaryExpression.Right.TypeMapping is NpgsqlArrayTypeMapping arrayMapping &&
})
.Append(" ");

requiresBrackets = RequiresBrackets(binaryExpression.Right);
requiresParentheses = RequiresParentheses(binaryExpression, binaryExpression.Right);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append("(");
}

Visit(binaryExpression.Right);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append(")");
}
Expand Down Expand Up @@ -902,16 +902,16 @@ public virtual Expression VisitUnknownBinary(PostgresUnknownBinaryExpression unk
{
Check.NotNull(unknownBinaryExpression, nameof(unknownBinaryExpression));

var requiresBrackets = RequiresBrackets(unknownBinaryExpression.Left);
var requiresParentheses = RequiresParentheses(unknownBinaryExpression, unknownBinaryExpression.Left);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append("(");
}

Visit(unknownBinaryExpression.Left);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append(")");
}
Expand All @@ -921,16 +921,16 @@ public virtual Expression VisitUnknownBinary(PostgresUnknownBinaryExpression unk
.Append(unknownBinaryExpression.Operator)
.Append(" ");

requiresBrackets = RequiresBrackets(unknownBinaryExpression.Right);
requiresParentheses = RequiresParentheses(unknownBinaryExpression, unknownBinaryExpression.Right);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append("(");
}

Visit(unknownBinaryExpression.Right);

if (requiresBrackets)
if (requiresParentheses)
{
Sql.Append(")");
}
Expand Down Expand Up @@ -1023,8 +1023,132 @@ public virtual Expression VisitPostgresFunction(PostgresFunctionExpression e)
return e;
}

private static bool RequiresBrackets(SqlExpression expression)
=> expression is SqlBinaryExpression || expression is LikeExpression || expression is PostgresBinaryExpression;
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool RequiresParentheses(SqlExpression outerExpression, SqlExpression innerExpression)
{
switch (innerExpression)
{
// PG doesn't support ~-, -~, ~~, -- so we add parentheses
case SqlUnaryExpression innerUnary when outerExpression is SqlUnaryExpression outerUnary
&& (innerUnary.OperatorType is ExpressionType.Negate || innerUnary.OperatorType is ExpressionType.Not && innerUnary.Type != typeof(bool))
&& (outerUnary.OperatorType is ExpressionType.Negate || outerUnary.OperatorType is ExpressionType.Not && outerUnary.Type != typeof(bool)):
return true;

// Copy paste of QuerySqlGenerator.RequiresParentheses for SqlBinaryExpression
case PostgresBinaryExpression innerBinary:
{
// If the provider defined precedence for the two expression, use that
if (TryGetOperatorInfo(outerExpression, out var outerPrecedence, out var isOuterAssociative)
&& TryGetOperatorInfo(innerExpression, out var innerPrecedence, out _))
{
return outerPrecedence.CompareTo(innerPrecedence) switch
{
> 0 => true,
< 0 => false,

// If both operators have the same precedence, add parentheses unless they're the same operator, and
// that operator is associative (e.g. a + b + c)
0 => outerExpression is not PostgresBinaryExpression outerBinary
|| outerBinary.OperatorType != innerBinary.OperatorType
|| !isOuterAssociative
// Arithmetic operators on floating points aren't associative, because of rounding errors.
|| outerExpression.Type == typeof(float)
|| outerExpression.Type == typeof(double)
|| innerExpression.Type == typeof(float)
|| innerExpression.Type == typeof(double)
};
}

// Otherwise always parenthesize for safety
return true;
}

default:
return base.RequiresParentheses(outerExpression, innerExpression);
}
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override bool TryGetOperatorInfo(SqlExpression expression, out int precedence, out bool isAssociative)
{
// See https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-PRECEDENCE
(precedence, isAssociative) = expression switch
{
// TODO: Exponent => 1300

SqlBinaryExpression sqlBinaryExpression => sqlBinaryExpression.OperatorType switch
{
// Multiplication, division, modulo
ExpressionType.Multiply => (1200, true),
ExpressionType.Divide => (1200, false),
ExpressionType.Modulo => (1200, false),

// Addition, subtraction (binary)
ExpressionType.Add => (1100, true),
ExpressionType.Subtract => (1100, false),

// All other native and user-defined operators => 1000
ExpressionType.LeftShift => (1000, true),
ExpressionType.RightShift => (1000, true),
ExpressionType.And when sqlBinaryExpression.Type != typeof(bool) => (1000, true),
ExpressionType.Or when sqlBinaryExpression.Type != typeof(bool) => (1000, true),

// Comparison operators
ExpressionType.Equal => (800, false),
ExpressionType.NotEqual => (800, false),
ExpressionType.LessThan => (800, false),
ExpressionType.LessThanOrEqual => (800, false),
ExpressionType.GreaterThan => (800, false),
ExpressionType.GreaterThanOrEqual => (800, false),

// Logical operators
ExpressionType.AndAlso => (500, true),
ExpressionType.OrElse => (500, true),
ExpressionType.And when sqlBinaryExpression.Type == typeof(bool) => (500, true),
ExpressionType.Or when sqlBinaryExpression.Type == typeof(bool) => (500, true),

_ => default,
},

SqlUnaryExpression sqlUnaryExpression => sqlUnaryExpression.OperatorType switch
{
ExpressionType.Convert => (1600, false),
ExpressionType.Negate => (1400, false),
ExpressionType.Not when sqlUnaryExpression.Type != typeof(bool) => (1000, false),
ExpressionType.Equal => (700, false), // IS NULL
ExpressionType.NotEqual => (700, false), // IS NOT NULL
ExpressionType.Not when sqlUnaryExpression.Type == typeof(bool) => (600, false),

_ => default,
},

// There's an "any other operator" category in the PG operator precedence table, we assign that a numeric value of 1000.
// TODO: Some operators here may be associative
PostgresBinaryExpression => (1000, false),

CollateExpression => (1000, false),
AtTimeZoneExpression => (1000, false),
InExpression => (900, false),
PostgresJsonTraversalExpression => (1000, false),
PostgresArrayIndexExpression => (1500, false),
PostgresAllExpression or PostgresAnyExpression => (800, false),
LikeExpression or PostgresILikeExpression or PostgresRegexMatchExpression => (900, false),

_ => default,
};

return precedence != default;
}

private void GenerateList<T>(
IReadOnlyList<T> items,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public override async Task Delete_where_hierarchy(bool async)
AssertSql(
"""
DELETE FROM "Animals" AS a
WHERE (a."CountryId" = 1) AND (a."Name" = 'Great spotted kiwi')
WHERE a."CountryId" = 1 AND a."Name" = 'Great spotted kiwi'
""");
}

Expand All @@ -31,7 +31,7 @@ public override async Task Delete_where_hierarchy_derived(bool async)
AssertSql(
"""
DELETE FROM "Animals" AS a
WHERE ((a."Discriminator" = 'Kiwi') AND (a."CountryId" = 1)) AND (a."Name" = 'Great spotted kiwi')
WHERE a."Discriminator" = 'Kiwi' AND a."CountryId" = 1 AND a."Name" = 'Great spotted kiwi'
""");
}

Expand All @@ -45,7 +45,7 @@ DELETE FROM "Countries" AS c
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE ((a."CountryId" = 1) AND (c."Id" = a."CountryId")) AND (a."CountryId" > 0)) > 0
WHERE a."CountryId" = 1 AND c."Id" = a."CountryId" AND a."CountryId" > 0) > 0
""");
}

Expand All @@ -59,7 +59,7 @@ DELETE FROM "Countries" AS c
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE (((a."CountryId" = 1) AND (c."Id" = a."CountryId")) AND (a."Discriminator" = 'Kiwi')) AND (a."CountryId" > 0)) > 0
WHERE a."CountryId" = 1 AND c."Id" = a."CountryId" AND a."Discriminator" = 'Kiwi' AND a."CountryId" > 0) > 0
""");
}

Expand All @@ -84,16 +84,16 @@ public override async Task Delete_GroupBy_Where_Select_First_3(bool async)
AssertSql(
"""
DELETE FROM "Animals" AS a
WHERE (a."CountryId" = 1) AND EXISTS (
WHERE a."CountryId" = 1 AND EXISTS (
SELECT 1
FROM "Animals" AS a0
WHERE a0."CountryId" = 1
GROUP BY a0."CountryId"
HAVING (count(*)::int < 3) AND ((
HAVING count(*)::int < 3 AND (
SELECT a1."Id"
FROM "Animals" AS a1
WHERE (a1."CountryId" = 1) AND (a0."CountryId" = a1."CountryId")
LIMIT 1) = a."Id"))
WHERE a1."CountryId" = 1 AND a0."CountryId" = a1."CountryId"
LIMIT 1) = a."Id")
""");
}

Expand All @@ -119,7 +119,7 @@ SELECT 1
FROM (
SELECT a0."Id", a0."CountryId", a0."Discriminator", a0."Name", a0."Species", a0."EagleId", a0."IsFlightless", a0."Group", a0."FoundOn"
FROM "Animals" AS a0
WHERE (a0."CountryId" = 1) AND (a0."Name" = 'Great spotted kiwi')
WHERE a0."CountryId" = 1 AND a0."Name" = 'Great spotted kiwi'
ORDER BY a0."Name" NULLS FIRST
LIMIT @__p_1 OFFSET @__p_0
) AS t
Expand All @@ -135,7 +135,7 @@ public override async Task Update_where_hierarchy(bool async)
"""
UPDATE "Animals" AS a
SET "Name" = 'Animal'
WHERE (a."CountryId" = 1) AND (a."Name" = 'Great spotted kiwi')
WHERE a."CountryId" = 1 AND a."Name" = 'Great spotted kiwi'
""");
}

Expand All @@ -154,7 +154,7 @@ public override async Task Update_where_hierarchy_derived(bool async)
"""
UPDATE "Animals" AS a
SET "Name" = 'Kiwi'
WHERE ((a."Discriminator" = 'Kiwi') AND (a."CountryId" = 1)) AND (a."Name" = 'Great spotted kiwi')
WHERE a."Discriminator" = 'Kiwi' AND a."CountryId" = 1 AND a."Name" = 'Great spotted kiwi'
""");
}

Expand All @@ -169,7 +169,7 @@ public override async Task Update_where_using_hierarchy(bool async)
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE ((a."CountryId" = 1) AND (c."Id" = a."CountryId")) AND (a."CountryId" > 0)) > 0
WHERE a."CountryId" = 1 AND c."Id" = a."CountryId" AND a."CountryId" > 0) > 0
""");
}

Expand All @@ -184,7 +184,7 @@ public override async Task Update_where_using_hierarchy_derived(bool async)
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE (((a."CountryId" = 1) AND (c."Id" = a."CountryId")) AND (a."Discriminator" = 'Kiwi')) AND (a."CountryId" > 0)) > 0
WHERE a."CountryId" = 1 AND c."Id" = a."CountryId" AND a."Discriminator" = 'Kiwi' AND a."CountryId" > 0) > 0
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public override async Task Delete_where_hierarchy_derived(bool async)
AssertSql(
"""
DELETE FROM "Animals" AS a
WHERE (a."Discriminator" = 'Kiwi') AND (a."Name" = 'Great spotted kiwi')
WHERE a."Discriminator" = 'Kiwi' AND a."Name" = 'Great spotted kiwi'
""");
}

Expand All @@ -45,7 +45,7 @@ DELETE FROM "Countries" AS c
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE (c."Id" = a."CountryId") AND (a."CountryId" > 0)) > 0
WHERE c."Id" = a."CountryId" AND a."CountryId" > 0) > 0
""");
}

Expand All @@ -59,7 +59,7 @@ DELETE FROM "Countries" AS c
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE ((c."Id" = a."CountryId") AND (a."Discriminator" = 'Kiwi')) AND (a."CountryId" > 0)) > 0
WHERE c."Id" = a."CountryId" AND a."Discriminator" = 'Kiwi' AND a."CountryId" > 0) > 0
""");
}

Expand Down Expand Up @@ -130,11 +130,11 @@ WHERE EXISTS (
SELECT 1
FROM "Animals" AS a0
GROUP BY a0."CountryId"
HAVING (count(*)::int < 3) AND ((
HAVING count(*)::int < 3 AND (
SELECT a1."Id"
FROM "Animals" AS a1
WHERE a0."CountryId" = a1."CountryId"
LIMIT 1) = a."Id"))
LIMIT 1) = a."Id")
""");
}

Expand Down Expand Up @@ -165,7 +165,7 @@ public override async Task Update_where_hierarchy_derived(bool async)
"""
UPDATE "Animals" AS a
SET "Name" = 'Kiwi'
WHERE (a."Discriminator" = 'Kiwi') AND (a."Name" = 'Great spotted kiwi')
WHERE a."Discriminator" = 'Kiwi' AND a."Name" = 'Great spotted kiwi'
""");
}

Expand All @@ -180,7 +180,7 @@ public override async Task Update_where_using_hierarchy(bool async)
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE (c."Id" = a."CountryId") AND (a."CountryId" > 0)) > 0
WHERE c."Id" = a."CountryId" AND a."CountryId" > 0) > 0
""");
}

Expand All @@ -195,7 +195,7 @@ public override async Task Update_where_using_hierarchy_derived(bool async)
WHERE (
SELECT count(*)::int
FROM "Animals" AS a
WHERE ((c."Id" = a."CountryId") AND (a."Discriminator" = 'Kiwi')) AND (a."CountryId" > 0)) > 0
WHERE c."Id" = a."CountryId" AND a."Discriminator" = 'Kiwi' AND a."CountryId" > 0) > 0
""");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ WHERE EXISTS (
SELECT 1
FROM "Posts" AS p0
LEFT JOIN "Blogs" AS b ON p0."BlogId" = b."Id"
WHERE (b."Title" IS NOT NULL AND (b."Title" LIKE 'Arthur%')) AND (p0."Id" = p."Id"))
WHERE b."Title" IS NOT NULL AND b."Title" LIKE 'Arthur%' AND p0."Id" = p."Id")
""");
}

Expand Down
Loading

0 comments on commit 2703bf4

Please sign in to comment.