Skip to content

Commit

Permalink
fix: properly group aggregates together
Browse files Browse the repository at this point in the history
fix: don't attempt to add multiple filter statements to a single aggregate
  • Loading branch information
zachdaniel committed Oct 7, 2024
1 parent 901570e commit 3139261
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 53 deletions.
140 changes: 103 additions & 37 deletions lib/aggregate.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1140,10 +1140,10 @@ defmodule AshSql.Aggregate do
type when is_atom(type) ->
case Ash.Resource.Info.field(related, aggregate.field).type do
{:array, _} ->
false
true

_ ->
true
false
end

_ ->
Expand Down Expand Up @@ -1279,7 +1279,7 @@ defmodule AshSql.Aggregate do
array_agg =
query.__ash_bindings__.sql_behaviour.list_aggregate(aggregate.resource)

{sorted, query} =
{sorted, include_nil_filter_field, query} =
if has_sort? || first_relationship.sort not in [nil, []] do
{sort, binding} =
if has_sort? do
Expand Down Expand Up @@ -1315,42 +1315,68 @@ defmodule AshSql.Aggregate do
query =
AshSql.Bindings.merge_expr_accumulator(query, acc)

{sort_expr, query}
{sort_expr, nil, query}
else
question_marks = Enum.map(sort_expr, fn _ -> " ? " end)

{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
["#{array_agg}(? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)", field] ++
sort_expr ++ [field]
)
{expr, include_nil_filter_field} =
if has_filter?(aggregate.query) and !is_single? do
{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
[
"#{array_agg}(? ORDER BY #{question_marks})",
field
] ++
sort_expr
)

{expr, field}
else
{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
[
"#{array_agg}(? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)",
field
] ++
sort_expr ++ [field]
)

{expr, nil}
end

{sort_expr, acc} =
AshSql.Expr.dynamic_expr(query, expr, query.__ash_bindings__, false)

query =
AshSql.Bindings.merge_expr_accumulator(query, acc)

{sort_expr, query}
{sort_expr, include_nil_filter_field, query}
end
else
case array_agg do
"array_agg" ->
{Ecto.Query.dynamic(
[row],
fragment("array_agg(?)", ^field)
), query}
), nil, query}

"any_value" ->
{Ecto.Query.dynamic(
[row],
fragment("any_value(?)", ^field)
), query}
), nil, query}
end
end

{query, filtered} =
filter_field(sorted, query, aggregate, relationship_path, is_single?)
filter_field(
sorted,
include_nil_filter_field,
query,
aggregate,
relationship_path,
is_single?
)

value =
if array_agg == "array_agg" do
Expand Down Expand Up @@ -1436,7 +1462,7 @@ defmodule AshSql.Aggregate do

has_sort? = has_sort?(aggregate.query)

{sorted, query} =
{sorted, include_nil_filter_field, query} =
if has_sort? || (first_relationship && first_relationship.sort not in [nil, []]) do
{sort, binding} =
if has_sort? do
Expand Down Expand Up @@ -1464,25 +1490,38 @@ defmodule AshSql.Aggregate do
""
end

expr =
{expr, include_nil_filter_field} =
if aggregate.include_nil? do
{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
["array_agg(#{distinct}? ORDER BY #{question_marks})", field] ++ sort_expr
)

expr
{expr, nil}
else
{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
[
"array_agg(#{distinct}? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)",
field
] ++
sort_expr ++ [field]
)

expr
if has_filter?(aggregate.query) and !is_single? do
{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
[
"array_agg(#{distinct}? ORDER BY #{question_marks})",
field
] ++
sort_expr ++ [field]
)

{expr, field}
else
{:ok, expr} =
Ash.Query.Function.Fragment.casted_new(
[
"array_agg(#{distinct}? ORDER BY #{question_marks}) FILTER (WHERE ? IS NOT NULL)",
field
] ++
sort_expr ++ [field]
)

{expr, nil}
end
end

{expr, acc} =
Expand All @@ -1491,23 +1530,30 @@ defmodule AshSql.Aggregate do
query =
AshSql.Bindings.merge_expr_accumulator(query, acc)

{expr, query}
{expr, include_nil_filter_field, query}
else
if Map.get(aggregate, :uniq?) do
{Ecto.Query.dynamic(
[row],
fragment("array_agg(DISTINCT ?)", ^field)
), query}
), nil, query}
else
{Ecto.Query.dynamic(
[row],
fragment("array_agg(?)", ^field)
), query}
), nil, query}
end
end

{query, filtered} =
filter_field(sorted, query, aggregate, relationship_path, is_single?)
filter_field(
sorted,
include_nil_filter_field,
query,
aggregate,
relationship_path,
is_single?
)

with_default =
if aggregate.default_value do
Expand Down Expand Up @@ -1608,7 +1654,7 @@ defmodule AshSql.Aggregate do
module.dynamic(opts, binding)
end

{query, filtered} = filter_field(field, query, aggregate, relationship_path, is_single?)
{query, filtered} = filter_field(field, nil, query, aggregate, relationship_path, is_single?)

with_default =
if aggregate.default_value do
Expand All @@ -1634,11 +1680,22 @@ defmodule AshSql.Aggregate do
select_or_merge(query, aggregate.name, cast)
end

defp filter_field(field, query, _aggregate, _relationship_path, true) do
{query, field}
defp filter_field(field, include_nil_filter_field, query, _aggregate, _relationship_path, true) do
if include_nil_filter_field do
{query, Ecto.Query.dynamic(filter(^field, not is_nil(^include_nil_filter_field)))}
else
{query, field}
end
end

defp filter_field(field, query, aggregate, relationship_path, _is_single?) do
defp filter_field(
field,
include_nil_filter_field,
query,
aggregate,
relationship_path,
_is_single?
) do
if has_filter?(aggregate.query) do
filter =
Ash.Filter.move_to_relationship_path(
Expand Down Expand Up @@ -1686,10 +1743,19 @@ defmodule AshSql.Aggregate do
)
)

{AshSql.Bindings.merge_expr_accumulator(query, acc),
Ecto.Query.dynamic(filter(^field, ^expr))}
if include_nil_filter_field do
{AshSql.Bindings.merge_expr_accumulator(query, acc),
Ecto.Query.dynamic(filter(^field, ^expr and not is_nil(^include_nil_filter_field)))}
else
{AshSql.Bindings.merge_expr_accumulator(query, acc),
Ecto.Query.dynamic(filter(^field, ^expr))}
end
else
{query, field}
if include_nil_filter_field do
{query, Ecto.Query.dynamic(filter(^field, not is_nil(^include_nil_filter_field)))}
else
{query, field}
end
end
end

Expand Down
34 changes: 18 additions & 16 deletions lib/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,9 @@ defmodule AshSql.Expr do
%Fragment{
embedded?: pred_embedded?,
arguments: [
raw: "SELECT COUNT(*) FROM unnest(",
raw: "(SELECT COUNT(*) FROM unnest(",
expr: list,
raw: ") AS item WHERE item IS TRUE"
raw: ") AS item WHERE item IS TRUE)"
]
},
bindings,
Expand All @@ -376,9 +376,9 @@ defmodule AshSql.Expr do
%Fragment{
embedded?: pred_embedded?,
arguments: [
raw: "SELECT COUNT(*) FROM unnest(",
raw: "(SELECT COUNT(*) FROM unnest(",
expr: list,
raw: ") AS item WHERE item IS NULL"
raw: ") AS item WHERE item IS NULL)"
]
},
bindings,
Expand Down Expand Up @@ -537,7 +537,7 @@ defmodule AshSql.Expr do
embedded?: pred_embedded?,
arguments:
[
raw: "CASE WHEN ",
raw: "(CASE WHEN ",
casted_expr: condition,
raw: " THEN ",
casted_expr: when_true
Expand All @@ -546,7 +546,7 @@ defmodule AshSql.Expr do
[
raw: " ELSE ",
casted_expr: when_false,
raw: " END"
raw: " END)"
]
},
bindings,
Expand Down Expand Up @@ -778,10 +778,10 @@ defmodule AshSql.Expr do
arguments =
case arguments do
[{:raw, raw} | rest] ->
[{:raw, "(#{raw}"} | rest]
[{:raw, raw} | rest]

arguments ->
[{:raw, "("} | arguments]
[{:raw, ""} | arguments]
end

arguments =
Expand All @@ -790,10 +790,10 @@ defmodule AshSql.Expr do
arguments

{:raw, _} ->
List.update_at(arguments, -1, fn {:raw, raw} -> {:raw, "#{raw})"} end)
arguments

_ ->
arguments ++ [{:raw, ")"}]
arguments ++ [{:raw, ""}]
end

{params, fragment_data, _, acc} =
Expand Down Expand Up @@ -955,9 +955,11 @@ defmodule AshSql.Expr do
%Fragment{
embedded?: pred_embedded?,
arguments: [
raw: "(",
casted_expr: left_expr,
raw: " || ",
casted_expr: right_expr
casted_expr: right_expr,
raw: ")"
]
},
bindings,
Expand Down Expand Up @@ -998,15 +1000,15 @@ defmodule AshSql.Expr do
%Ash.Query.Function.Fragment{
embedded?: pred_embedded?,
arguments: [
raw: "CASE WHEN (",
raw: "(CASE WHEN (",
casted_expr: left_expr,
raw: " = FALSE OR ",
casted_expr: left_expr,
raw: " IS NULL) THEN ",
casted_expr: right_expr,
raw: " ELSE ",
casted_expr: left_expr,
raw: "END"
raw: "END)"
]
},
bindings,
Expand Down Expand Up @@ -1048,15 +1050,15 @@ defmodule AshSql.Expr do
%Fragment{
embedded?: pred_embedded?,
arguments: [
raw: "CASE WHEN (",
raw: "(CASE WHEN (",
casted_expr: left_expr,
raw: " = FALSE OR ",
casted_expr: left_expr,
raw: " IS NULL) THEN ",
casted_expr: left_expr,
raw: " ELSE ",
casted_expr: right_expr,
raw: "END"
raw: "END)"
]
},
bindings,
Expand Down Expand Up @@ -1089,7 +1091,7 @@ defmodule AshSql.Expr do

do_dynamic_expr(
query,
%Ash.Query.Function.Fragment{arguments: [raw: "", expr: string, raw: "::citext"]},
%Ash.Query.Function.Fragment{arguments: [raw: "(", expr: string, raw: "::citext)"]},
bindings,
embedded?,
acc,
Expand Down

0 comments on commit 3139261

Please sign in to comment.