diff --git a/lib/dataloader/ecto.ex b/lib/dataloader/ecto.ex index b9f511d..14e0939 100644 --- a/lib/dataloader/ecto.ex +++ b/lib/dataloader/ecto.ex @@ -324,8 +324,8 @@ if Code.ensure_loaded?(Ecto) do inputs :: [any], repo_opts :: repo_opts ) :: [any] - def run_batch(repo, _queryable, query, col, inputs, repo_opts) do - results = load_rows(col, inputs, query, repo, repo_opts) + def run_batch(repo, queryable, query, col, inputs, repo_opts) do + results = load_rows(col, inputs, queryable, query, repo, repo_opts) grouped_results = group_results(results, col) for value <- inputs do @@ -335,9 +335,33 @@ if Code.ensure_loaded?(Ecto) do end end - defp load_rows(col, inputs, query, repo, repo_opts) do - query - |> where([q], field(q, ^col) in ^inputs) + defp load_rows(col, inputs, queryable, query, repo, repo_opts) do + case query do + %Ecto.Query{limit: limit, offset: offset} when not is_nil(limit) or not is_nil(offset) -> + load_rows_lateral(col, inputs, queryable, query, repo, repo_opts) + + _ -> + query + |> where([q], field(q, ^col) in ^inputs) + |> repo.all(repo_opts) + end + end + + defp load_rows_lateral(col, inputs, queryable, query, repo, repo_opts) do + # Approximate a postgres unnest with a subquery + inputs_query = + queryable + |> where([q], field(q, ^col) in ^inputs) + |> select(^[col]) + |> distinct(true) + + query = + query + |> where([q], field(q, ^col) == field(parent_as(:input), ^col)) + + from(input in subquery(inputs_query), as: :input) + |> join(:inner_lateral, q in subquery(query)) + |> select([_input, q], q) |> repo.all(repo_opts) end diff --git a/mix.lock b/mix.lock index 961bd27..5f6aaf9 100644 --- a/mix.lock +++ b/mix.lock @@ -4,13 +4,13 @@ "decimal": {:hex, :decimal, "1.8.1", "a4ef3f5f3428bdbc0d35374029ffcf4ede8533536fa79896dd450168d9acdf3c", [:mix], [], "hexpm", "3cb154b00225ac687f6cbd4acc4b7960027c757a5152b369923ead9ddbca7aec"}, "dialyxir": {:hex, :dialyxir, "0.5.1", "b331b091720fd93e878137add264bac4f644e1ddae07a70bf7062c7862c4b952", [:mix], [], "hexpm", "6c32a70ed5d452c6650916555b1f96c79af5fc4bf286997f8b15f213de786f73"}, "earmark": {:hex, :earmark, "1.4.3", "364ca2e9710f6bff494117dbbd53880d84bebb692dafc3a78eb50aa3183f2bfd", [:mix], [], "hexpm", "8cf8a291ebf1c7b9539e3cddb19e9cef066c2441b1640f13c34c1d3cfc825fec"}, - "ecto": {:hex, :ecto, "3.3.1", "82ab74298065bf0c64ca299f6c6785e68ea5d6b980883ee80b044499df35aba1", [:mix], [{:decimal, "~> 1.6", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "e6c614dfe3bcff2d575ce16d815dbd43f4ee1844599a83de1eea81976a31c174"}, - "ecto_sql": {:hex, :ecto_sql, "3.3.2", "92804e0de69bb63e621273c3492252cb08a29475c05d40eeb6f41ad2d483cfd3", [:mix], [{:db_connection, "~> 2.2", [hex: :db_connection, repo: "hexpm", optional: false]}, {:ecto, "~> 3.3", [hex: :ecto, repo: "hexpm", optional: false]}, {:myxql, "~> 0.3.0", [hex: :myxql, repo: "hexpm", optional: true]}, {:postgrex, "~> 0.15.0", [hex: :postgrex, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "b82d89d4e6a9f7f7f04783b07e8b0af968e0be2f01ee4b39047fe727c5c07471"}, + "ecto": {:hex, :ecto, "3.4.5", "2bcd262f57b2c888b0bd7f7a28c8a48aa11dc1a2c6a858e45dd8f8426d504265", [:mix], [{:decimal, "~> 1.6 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "8c6d1d4d524559e9b7a062f0498e2c206122552d63eacff0a6567ffe7a8e8691"}, + "ecto_sql": {:hex, :ecto_sql, "3.4.4", "d28bac2d420f708993baed522054870086fd45016a9d09bb2cd521b9c48d32ea", [:mix], [{:db_connection, "~> 2.2", [hex: :db_connection, repo: "hexpm", optional: false]}, {:ecto, "~> 3.4.3", [hex: :ecto, repo: "hexpm", optional: false]}, {:myxql, "~> 0.3.0 or ~> 0.4.0", [hex: :myxql, repo: "hexpm", optional: true]}, {:postgrex, "~> 0.15.0", [hex: :postgrex, repo: "hexpm", optional: true]}, {:tds, "~> 2.1.0", [hex: :tds, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "edb49af715dd72f213b66adfd0f668a43c17ed510b5d9ac7528569b23af57fe8"}, "ex_doc": {:hex, :ex_doc, "0.21.2", "caca5bc28ed7b3bdc0b662f8afe2bee1eedb5c3cf7b322feeeb7c6ebbde089d6", [:mix], [{:earmark, "~> 1.3.3 or ~> 1.4", [hex: :earmark, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}], "hexpm", "f1155337ae17ff7a1255217b4c1ceefcd1860b7ceb1a1874031e7a861b052e39"}, "makeup": {:hex, :makeup, "1.0.0", "671df94cf5a594b739ce03b0d0316aa64312cee2574b6a44becb83cd90fb05dc", [:mix], [{:nimble_parsec, "~> 0.5.0", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "a10c6eb62cca416019663129699769f0c2ccf39428b3bb3c0cb38c718a0c186d"}, "makeup_elixir": {:hex, :makeup_elixir, "0.14.0", "cf8b7c66ad1cff4c14679698d532f0b5d45a3968ffbcbfd590339cb57742f1ae", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "d4b316c7222a85bbaa2fd7c6e90e37e953257ad196dc229505137c5e505e9eff"}, "nimble_parsec": {:hex, :nimble_parsec, "0.5.3", "def21c10a9ed70ce22754fdeea0810dafd53c2db3219a0cd54cf5526377af1c6", [:mix], [], "hexpm", "589b5af56f4afca65217a1f3eb3fee7e79b09c40c742fddc1c312b3ac0b3399f"}, "poolboy": {:hex, :poolboy, "1.5.1", "6b46163901cfd0a1b43d692657ed9d7e599853b3b21b95ae5ae0a777cf9b6ca8", [:rebar], [], "hexpm"}, - "postgrex": {:hex, :postgrex, "0.15.3", "5806baa8a19a68c4d07c7a624ccdb9b57e89cbc573f1b98099e3741214746ae4", [:mix], [{:connection, "~> 1.0", [hex: :connection, repo: "hexpm", optional: false]}, {:db_connection, "~> 2.1", [hex: :db_connection, repo: "hexpm", optional: false]}, {:decimal, "~> 1.5", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "4737ce62a31747b4c63c12b20c62307e51bb4fcd730ca0c32c280991e0606c90"}, + "postgrex": {:hex, :postgrex, "0.15.5", "aec40306a622d459b01bff890fa42f1430dac61593b122754144ad9033a2152f", [:mix], [{:connection, "~> 1.0", [hex: :connection, repo: "hexpm", optional: false]}, {:db_connection, "~> 2.1", [hex: :db_connection, repo: "hexpm", optional: false]}, {:decimal, "~> 1.5", [hex: :decimal, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "ed90c81e1525f65a2ba2279dbcebf030d6d13328daa2f8088b9661eb9143af7f"}, "telemetry": {:hex, :telemetry, "0.4.1", "ae2718484892448a24470e6aa341bc847c3277bfb8d4e9289f7474d752c09c7f", [:rebar3], [], "hexpm", "4738382e36a0a9a2b6e25d67c960e40e1a2c95560b9f936d8e29de8cd858480f"}, } diff --git a/test/dataloader/ecto/limit_query_test.exs b/test/dataloader/ecto/limit_query_test.exs new file mode 100644 index 0000000..4f0c09f --- /dev/null +++ b/test/dataloader/ecto/limit_query_test.exs @@ -0,0 +1,64 @@ +defmodule Dataloader.LimitQueryTest do + use ExUnit.Case, async: true + + alias Dataloader.{User, Post} + import Ecto.Query + alias Dataloader.TestRepo, as: Repo + + setup do + :ok = Ecto.Adapters.SQL.Sandbox.checkout(Repo) + + test_pid = self() + + source = + Dataloader.Ecto.new( + Repo, + query: &query(&1, &2, test_pid) + ) + + loader = + Dataloader.new() + |> Dataloader.add_source(Test, source) + + {:ok, loader: loader} + end + + defp query(Post, %{limit: limit}, test_pid) do + send(test_pid, :querying) + + Post + |> where([p], is_nil(p.deleted_at)) + |> order_by(asc: :id) + |> limit(^limit) + end + + defp query(queryable, _args, test_pid) do + send(test_pid, :querying) + queryable + end + + test "Query limit does not apply globally", %{loader: loader} do + user1 = %User{username: "Ben Wilson"} |> Repo.insert!() + user2 = %User{username: "Bruce Williams"} |> Repo.insert!() + + [post1, _post2, post3, _post4] = + [ + %Post{user_id: user1.id, title: "foo"}, + %Post{user_id: user1.id, title: "baz"}, + %Post{user_id: user2.id, title: "bar"}, + %Post{user_id: user2.id, title: "qux"} + ] + |> Enum.map(&Repo.insert!/1) + + args = {{:many, Post}, %{limit: 1}} + + loader = + loader + |> Dataloader.load(Test, args, user_id: user1.id) + |> Dataloader.load(Test, args, user_id: user2.id) + |> Dataloader.run() + + assert [post1] == Dataloader.get(loader, Test, args, user_id: user1.id) + assert [post3] == Dataloader.get(loader, Test, args, user_id: user2.id) + end +end