From 31392613b8dbbf804b7c313ede26e6c8131ea660 Mon Sep 17 00:00:00 2001 From: Zach Daniel Date: Mon, 7 Oct 2024 10:59:12 -0400 Subject: [PATCH] fix: properly group aggregates together fix: don't attempt to add multiple filter statements to a single aggregate --- lib/aggregate.ex | 140 ++++++++++++++++++++++++++++++++++------------- lib/expr.ex | 34 ++++++------ 2 files changed, 121 insertions(+), 53 deletions(-) diff --git a/lib/aggregate.ex b/lib/aggregate.ex index 5e08199..31ce3b3 100644 --- a/lib/aggregate.ex +++ b/lib/aggregate.ex @@ -1140,10 +1140,10 @@ defmodule AshSql.Aggregate do type when is_atom(type) -> case Ash.Resource.Info.field(related, aggregate.field).type do {:array, _} -> - false + true _ -> - true + false end _ -> @@ -1279,7 +1279,7 @@ defmodule AshSql.Aggregate do array_agg = query.__ash_bindings__.sql_behaviour.list_aggregate(aggregate.resource) - {sorted, query} = + {sorted, include_nil_filter_field, query} = if has_sort? || first_relationship.sort not in [nil, []] do {sort, binding} = if has_sort? do @@ -1315,15 +1315,34 @@ defmodule AshSql.Aggregate do query = AshSql.Bindings.merge_expr_accumulator(query, acc) - {sort_expr, query} + {sort_expr, nil, query} else question_marks = Enum.map(sort_expr, fn _ -> " ? " end) - {:ok, expr} = - Ash.Query.Function.Fragment.casted_new( - ["#{array_agg}(? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)", field] ++ - sort_expr ++ [field] - ) + {expr, include_nil_filter_field} = + if has_filter?(aggregate.query) and !is_single? do + {:ok, expr} = + Ash.Query.Function.Fragment.casted_new( + [ + "#{array_agg}(? ORDER BY #{question_marks})", + field + ] ++ + sort_expr + ) + + {expr, field} + else + {:ok, expr} = + Ash.Query.Function.Fragment.casted_new( + [ + "#{array_agg}(? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)", + field + ] ++ + sort_expr ++ [field] + ) + + {expr, nil} + end {sort_expr, acc} = AshSql.Expr.dynamic_expr(query, expr, query.__ash_bindings__, false) @@ -1331,7 +1350,7 @@ defmodule AshSql.Aggregate do query = AshSql.Bindings.merge_expr_accumulator(query, acc) - {sort_expr, query} + {sort_expr, include_nil_filter_field, query} end else case array_agg do @@ -1339,18 +1358,25 @@ defmodule AshSql.Aggregate do {Ecto.Query.dynamic( [row], fragment("array_agg(?)", ^field) - ), query} + ), nil, query} "any_value" -> {Ecto.Query.dynamic( [row], fragment("any_value(?)", ^field) - ), query} + ), nil, query} end end {query, filtered} = - filter_field(sorted, query, aggregate, relationship_path, is_single?) + filter_field( + sorted, + include_nil_filter_field, + query, + aggregate, + relationship_path, + is_single? + ) value = if array_agg == "array_agg" do @@ -1436,7 +1462,7 @@ defmodule AshSql.Aggregate do has_sort? = has_sort?(aggregate.query) - {sorted, query} = + {sorted, include_nil_filter_field, query} = if has_sort? || (first_relationship && first_relationship.sort not in [nil, []]) do {sort, binding} = if has_sort? do @@ -1464,25 +1490,38 @@ defmodule AshSql.Aggregate do "" end - expr = + {expr, include_nil_filter_field} = if aggregate.include_nil? do {:ok, expr} = Ash.Query.Function.Fragment.casted_new( ["array_agg(#{distinct}? ORDER BY #{question_marks})", field] ++ sort_expr ) - expr + {expr, nil} else - {:ok, expr} = - Ash.Query.Function.Fragment.casted_new( - [ - "array_agg(#{distinct}? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)", - field - ] ++ - sort_expr ++ [field] - ) - - expr + if has_filter?(aggregate.query) and !is_single? do + {:ok, expr} = + Ash.Query.Function.Fragment.casted_new( + [ + "array_agg(#{distinct}? ORDER BY #{question_marks})", + field + ] ++ + sort_expr ++ [field] + ) + + {expr, field} + else + {:ok, expr} = + Ash.Query.Function.Fragment.casted_new( + [ + "array_agg(#{distinct}? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)", + field + ] ++ + sort_expr ++ [field] + ) + + {expr, nil} + end end {expr, acc} = @@ -1491,23 +1530,30 @@ defmodule AshSql.Aggregate do query = AshSql.Bindings.merge_expr_accumulator(query, acc) - {expr, query} + {expr, include_nil_filter_field, query} else if Map.get(aggregate, :uniq?) do {Ecto.Query.dynamic( [row], fragment("array_agg(DISTINCT ?)", ^field) - ), query} + ), nil, query} else {Ecto.Query.dynamic( [row], fragment("array_agg(?)", ^field) - ), query} + ), nil, query} end end {query, filtered} = - filter_field(sorted, query, aggregate, relationship_path, is_single?) + filter_field( + sorted, + include_nil_filter_field, + query, + aggregate, + relationship_path, + is_single? + ) with_default = if aggregate.default_value do @@ -1608,7 +1654,7 @@ defmodule AshSql.Aggregate do module.dynamic(opts, binding) end - {query, filtered} = filter_field(field, query, aggregate, relationship_path, is_single?) + {query, filtered} = filter_field(field, nil, query, aggregate, relationship_path, is_single?) with_default = if aggregate.default_value do @@ -1634,11 +1680,22 @@ defmodule AshSql.Aggregate do select_or_merge(query, aggregate.name, cast) end - defp filter_field(field, query, _aggregate, _relationship_path, true) do - {query, field} + defp filter_field(field, include_nil_filter_field, query, _aggregate, _relationship_path, true) do + if include_nil_filter_field do + {query, Ecto.Query.dynamic(filter(^field, not is_nil(^include_nil_filter_field)))} + else + {query, field} + end end - defp filter_field(field, query, aggregate, relationship_path, _is_single?) do + defp filter_field( + field, + include_nil_filter_field, + query, + aggregate, + relationship_path, + _is_single? + ) do if has_filter?(aggregate.query) do filter = Ash.Filter.move_to_relationship_path( @@ -1686,10 +1743,19 @@ defmodule AshSql.Aggregate do ) ) - {AshSql.Bindings.merge_expr_accumulator(query, acc), - Ecto.Query.dynamic(filter(^field, ^expr))} + if include_nil_filter_field do + {AshSql.Bindings.merge_expr_accumulator(query, acc), + Ecto.Query.dynamic(filter(^field, ^expr and not is_nil(^include_nil_filter_field)))} + else + {AshSql.Bindings.merge_expr_accumulator(query, acc), + Ecto.Query.dynamic(filter(^field, ^expr))} + end else - {query, field} + if include_nil_filter_field do + {query, Ecto.Query.dynamic(filter(^field, not is_nil(^include_nil_filter_field)))} + else + {query, field} + end end end diff --git a/lib/expr.ex b/lib/expr.ex index 7480797..b308be2 100644 --- a/lib/expr.ex +++ b/lib/expr.ex @@ -360,9 +360,9 @@ defmodule AshSql.Expr do %Fragment{ embedded?: pred_embedded?, arguments: [ - raw: "SELECT COUNT(*) FROM unnest(", + raw: "(SELECT COUNT(*) FROM unnest(", expr: list, - raw: ") AS item WHERE item IS TRUE" + raw: ") AS item WHERE item IS TRUE)" ] }, bindings, @@ -376,9 +376,9 @@ defmodule AshSql.Expr do %Fragment{ embedded?: pred_embedded?, arguments: [ - raw: "SELECT COUNT(*) FROM unnest(", + raw: "(SELECT COUNT(*) FROM unnest(", expr: list, - raw: ") AS item WHERE item IS NULL" + raw: ") AS item WHERE item IS NULL)" ] }, bindings, @@ -537,7 +537,7 @@ defmodule AshSql.Expr do embedded?: pred_embedded?, arguments: [ - raw: "CASE WHEN ", + raw: "(CASE WHEN ", casted_expr: condition, raw: " THEN ", casted_expr: when_true @@ -546,7 +546,7 @@ defmodule AshSql.Expr do [ raw: " ELSE ", casted_expr: when_false, - raw: " END" + raw: " END)" ] }, bindings, @@ -778,10 +778,10 @@ defmodule AshSql.Expr do arguments = case arguments do [{:raw, raw} | rest] -> - [{:raw, "(#{raw}"} | rest] + [{:raw, raw} | rest] arguments -> - [{:raw, "("} | arguments] + [{:raw, ""} | arguments] end arguments = @@ -790,10 +790,10 @@ defmodule AshSql.Expr do arguments {:raw, _} -> - List.update_at(arguments, -1, fn {:raw, raw} -> {:raw, "#{raw})"} end) + arguments _ -> - arguments ++ [{:raw, ")"}] + arguments ++ [{:raw, ""}] end {params, fragment_data, _, acc} = @@ -955,9 +955,11 @@ defmodule AshSql.Expr do %Fragment{ embedded?: pred_embedded?, arguments: [ + raw: "(", casted_expr: left_expr, raw: " || ", - casted_expr: right_expr + casted_expr: right_expr, + raw: ")" ] }, bindings, @@ -998,7 +1000,7 @@ defmodule AshSql.Expr do %Ash.Query.Function.Fragment{ embedded?: pred_embedded?, arguments: [ - raw: "CASE WHEN (", + raw: "(CASE WHEN (", casted_expr: left_expr, raw: " = FALSE OR ", casted_expr: left_expr, @@ -1006,7 +1008,7 @@ defmodule AshSql.Expr do casted_expr: right_expr, raw: " ELSE ", casted_expr: left_expr, - raw: "END" + raw: "END)" ] }, bindings, @@ -1048,7 +1050,7 @@ defmodule AshSql.Expr do %Fragment{ embedded?: pred_embedded?, arguments: [ - raw: "CASE WHEN (", + raw: "(CASE WHEN (", casted_expr: left_expr, raw: " = FALSE OR ", casted_expr: left_expr, @@ -1056,7 +1058,7 @@ defmodule AshSql.Expr do casted_expr: left_expr, raw: " ELSE ", casted_expr: right_expr, - raw: "END" + raw: "END)" ] }, bindings, @@ -1089,7 +1091,7 @@ defmodule AshSql.Expr do do_dynamic_expr( query, - %Ash.Query.Function.Fragment{arguments: [raw: "", expr: string, raw: "::citext"]}, + %Ash.Query.Function.Fragment{arguments: [raw: "(", expr: string, raw: "::citext)"]}, bindings, embedded?, acc,