Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner: fix correlated aggregates which should be evaluated in outer query #21431

Merged
merged 38 commits into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d00e7b4
planner: fix the behavior of correlated aggregate in subquery
dyzsr Nov 30, 2020
6ea8272
extract aggfuncs inside subquery's where clause and window functions
dyzsr Dec 1, 2020
7c3d74a
Merge branch 'master' into agginsubquery
dyzsr Dec 3, 2020
d08ebe5
updates
dyzsr Dec 3, 2020
2101350
Merge branch 'master' into agginsubquery
dyzsr Dec 3, 2020
9fa1b84
Merge branch 'master' into agginsubquery
dyzsr Dec 4, 2020
5c5985c
bug fixes
dyzsr Dec 4, 2020
939dbaf
update tests
dyzsr Dec 4, 2020
309e007
add resultSetNode cache & modify tests
dyzsr Dec 4, 2020
691373e
resolve pattern in, exists, eq subquery
dyzsr Dec 4, 2020
74fcf5a
refactors
dyzsr Dec 7, 2020
fba20e8
cleanup code
dyzsr Dec 8, 2020
6187fc3
updates
dyzsr Dec 8, 2020
6d2afce
Merge branch 'master' into agginsubquery
dyzsr Dec 8, 2020
2897eb6
support nested aggregates
dyzsr Dec 8, 2020
0c3bff7
add comments
dyzsr Dec 8, 2020
a619159
fix explain_easy
dyzsr Dec 8, 2020
130e0ec
update testcases
dyzsr Dec 8, 2020
e6f68b5
Revert "fix explain_easy"
dyzsr Dec 8, 2020
3461bae
updates
dyzsr Dec 8, 2020
e2e8581
Merge branch 'master' into agginsubquery
dyzsr Dec 8, 2020
e84804f
add clause code
dyzsr Dec 9, 2020
cd94fdb
Merge branch 'master' into agginsubquery
dyzsr Dec 10, 2020
e159293
update addressing comments
dyzsr Dec 10, 2020
36cf039
Merge branch 'master' into agginsubquery
dyzsr Dec 11, 2020
c141f38
add plan test
dyzsr Dec 11, 2020
ce204b8
collect from GROUP BY & bug fixes
dyzsr Dec 11, 2020
4d9399d
Merge branch 'master' into agginsubquery
dyzsr Dec 11, 2020
4ede850
Merge branch 'master' into agginsubquery
dyzsr Dec 17, 2020
452fc0b
remove unnecessary methods
dyzsr Dec 17, 2020
1b15ec5
bug fixes & add explain tests
dyzsr Dec 17, 2020
d4a3064
fix CI
dyzsr Dec 18, 2020
c44df47
fix CI
dyzsr Dec 18, 2020
dce2cd3
Merge branch 'master' into agginsubquery
dyzsr Dec 18, 2020
8ef33b0
updates
dyzsr Dec 18, 2020
a60946d
add testcases
dyzsr Dec 18, 2020
1741414
Merge branch 'master' into agginsubquery
ti-srebot Dec 18, 2020
e06c3c1
Merge branch 'master' into agginsubquery
ti-srebot Dec 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions cmd/explaintest/r/explain_easy.result
Original file line number Diff line number Diff line change
Expand Up @@ -651,9 +651,9 @@ drop table if exists t;
create table t(a int, b int, c int);
explain select * from (select * from t order by (select 2)) t order by a, b;
id estRows task access object operator info
Sort_12 10000.00 root test.t.a, test.t.b
└─TableReader_18 10000.00 root data:TableFullScan_17
└─TableFullScan_17 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
Sort_13 10000.00 root test.t.a, test.t.b
└─TableReader_19 10000.00 root data:TableFullScan_18
└─TableFullScan_18 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select * from (select * from t order by c) t order by a, b;
id estRows task access object operator info
Sort_6 10000.00 root test.t.a, test.t.b
Expand Down Expand Up @@ -784,3 +784,65 @@ Update_4 N/A root N/A
├─IndexRangeScan_9(Build) 0.10 cop[tikv] table:t, index:a(a, b) range:[0xFA34E1093CB428485734E3917F000000 "xb",0xFA34E1093CB428485734E3917F000000 "xb"], keep order:false, stats:pseudo
└─TableRowIDScan_10(Probe) 0.10 cop[tikv] table:t keep order:false, stats:pseudo
drop table if exists t;
create table t(a int, b int);
explain select (select count(n.a) from t) from t n;
id estRows task access object operator info
Projection_9 1.00 root Column#8
└─Apply_11 1.00 root CARTESIAN left outer join
├─StreamAgg_23(Build) 1.00 root funcs:count(Column#13)->Column#7
│ └─TableReader_24 1.00 root data:StreamAgg_15
│ └─StreamAgg_15 1.00 cop[tikv] funcs:count(test.t.a)->Column#13
│ └─TableFullScan_22 10000.00 cop[tikv] table:n keep order:false, stats:pseudo
└─MaxOneRow_27(Probe) 1.00 root
└─Projection_28 2.00 root Column#7
└─TableReader_30 2.00 root data:TableFullScan_29
└─TableFullScan_29 2.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select (select sum((select count(a)))) from t;
id estRows task access object operator info
Projection_23 1.00 root Column#7
└─Apply_25 1.00 root CARTESIAN left outer join
├─StreamAgg_37(Build) 1.00 root funcs:count(Column#15)->Column#5
│ └─TableReader_38 1.00 root data:StreamAgg_29
│ └─StreamAgg_29 1.00 cop[tikv] funcs:count(test.t.a)->Column#15
│ └─TableFullScan_36 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─HashAgg_43(Probe) 1.00 root funcs:sum(Column#12)->Column#7
└─HashJoin_44 1.00 root CARTESIAN left outer join
├─HashAgg_49(Build) 1.00 root group by:1, funcs:sum(Column#16)->Column#12
│ └─Projection_54 1.00 root cast(Column#6, decimal(42,0) BINARY)->Column#16
│ └─MaxOneRow_50 1.00 root
│ └─Projection_51 1.00 root Column#5
│ └─TableDual_52 1.00 root rows:1
└─TableDual_46(Probe) 1.00 root rows:1
explain select count(a) from t group by b order by (select count(a));
id estRows task access object operator info
Sort_12 8000.00 root Column#4
└─HashJoin_14 8000.00 root CARTESIAN left outer join
├─TableDual_24(Build) 1.00 root rows:1
└─HashAgg_20(Probe) 8000.00 root group by:test.t.b, funcs:count(Column#8)->Column#4
└─TableReader_21 8000.00 root data:HashAgg_16
└─HashAgg_16 8000.00 cop[tikv] group by:test.t.b, funcs:count(test.t.a)->Column#8
└─TableFullScan_19 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
explain select (select sum(count(a))) from t;
id estRows task access object operator info
Projection_11 1.00 root Column#5
└─Apply_13 1.00 root CARTESIAN left outer join
├─StreamAgg_25(Build) 1.00 root funcs:count(Column#8)->Column#4
│ └─TableReader_26 1.00 root data:StreamAgg_17
│ └─StreamAgg_17 1.00 cop[tikv] funcs:count(test.t.a)->Column#8
│ └─TableFullScan_24 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
└─StreamAgg_32(Probe) 1.00 root funcs:sum(Column#9)->Column#5
└─Projection_39 1.00 root cast(Column#4, decimal(42,0) BINARY)->Column#9
└─TableDual_37 1.00 root rows:1
explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a));
id estRows task access object operator info
Projection_16 8000.00 root Column#4, Column#4, Column#5
└─Sort_17 8000.00 root Column#5
└─HashJoin_19 8000.00 root CARTESIAN left outer join
├─TableDual_33(Build) 1.00 root rows:1
└─HashJoin_21(Probe) 8000.00 root CARTESIAN left outer join
├─TableDual_31(Build) 1.00 root rows:1
└─HashAgg_27(Probe) 8000.00 root group by:test.t.b, funcs:sum(Column#13)->Column#4, funcs:count(Column#14)->Column#5
└─TableReader_28 8000.00 root data:HashAgg_23
└─HashAgg_23 8000.00 cop[tikv] group by:test.t.b, funcs:sum(test.t.a)->Column#13, funcs:count(test.t.a)->Column#14
└─TableFullScan_26 10000.00 cop[tikv] table:t keep order:false, stats:pseudo
drop table if exists t;
8 changes: 4 additions & 4 deletions cmd/explaintest/r/tpch.result
Original file line number Diff line number Diff line change
Expand Up @@ -711,10 +711,10 @@ and n_name = 'MOZAMBIQUE'
order by
value desc;
id estRows task access object operator info
Projection_57 1304801.67 root tpch.partsupp.ps_partkey, Column#18
└─Sort_58 1304801.67 root Column#18:desc
└─Selection_60 1304801.67 root gt(Column#18, NULL)
└─HashAgg_63 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#18, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey
Projection_57 1304801.67 root tpch.partsupp.ps_partkey, Column#35
└─Sort_58 1304801.67 root Column#35:desc
└─Selection_60 1304801.67 root gt(Column#35, NULL)
└─HashAgg_63 1631002.09 root group by:Column#44, funcs:sum(Column#42)->Column#35, funcs:firstrow(Column#43)->tpch.partsupp.ps_partkey
└─Projection_89 1631002.09 root mul(tpch.partsupp.ps_supplycost, cast(tpch.partsupp.ps_availqty, decimal(20,0) BINARY))->Column#42, tpch.partsupp.ps_partkey, tpch.partsupp.ps_partkey
└─HashJoin_67 1631002.09 root inner join, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)]
├─HashJoin_80(Build) 20000.00 root inner join, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)]
Expand Down
8 changes: 8 additions & 0 deletions cmd/explaintest/t/explain_easy.test
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,11 @@ create table t(a binary(16) not null, b varchar(2) default null, c varchar(100)
explain select * from t where a=x'FA34E1093CB428485734E3917F000000' and b='xb';
explain update t set c = 'ssss' where a=x'FA34E1093CB428485734E3917F000000' and b='xb';
drop table if exists t;

create table t(a int, b int);
explain select (select count(n.a) from t) from t n;
explain select (select sum((select count(a)))) from t;
explain select count(a) from t group by b order by (select count(a));
explain select (select sum(count(a))) from t;
explain select sum(a), (select sum(a)), count(a) from t group by b order by (select count(a));
drop table if exists t;
8 changes: 4 additions & 4 deletions executor/testdata/agg_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@
"Name": "TestIssue12759HashAggCalledByApply",
"Cases": [
[
"Projection_28 1.00 root Column#3, Column#6, Column#9, Column#12",
"Projection_28 1.00 root Column#9, Column#10, Column#11, Column#12",
"└─Apply_30 1.00 root CARTESIAN left outer join",
" ├─Apply_32(Build) 1.00 root CARTESIAN left outer join",
" │ ├─Apply_34(Build) 1.00 root CARTESIAN left outer join",
" │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#3, funcs:firstrow(Column#23)->test.test.a",
" │ │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#22)->Column#9, funcs:firstrow(Column#23)->test.test.a",
" │ │ │ └─TableReader_40 1.00 root data:HashAgg_35",
" │ │ │ └─HashAgg_35 1.00 cop[tikv] funcs:sum(test.test.a)->Column#22, funcs:firstrow(test.test.a)->Column#23",
" │ │ │ └─TableFullScan_38 10000.00 cop[tikv] table:tt keep order:false, stats:pseudo",
" │ │ └─Projection_43(Probe) 1.00 root <nil>->Column#6",
" │ │ └─Projection_43(Probe) 1.00 root <nil>->Column#10",
" │ │ └─Limit_44 1.00 root offset:0, count:1",
" │ │ └─TableReader_50 1.00 root data:Limit_49",
" │ │ └─Limit_49 1.00 cop[tikv] offset:0, count:1",
" │ │ └─Selection_48 1.00 cop[tikv] eq(test.test.a, test.test.a)",
" │ │ └─TableFullScan_47 1000.00 cop[tikv] table:test keep order:false, stats:pseudo",
" │ └─Projection_54(Probe) 1.00 root <nil>->Column#9",
" │ └─Projection_54(Probe) 1.00 root <nil>->Column#11",
" │ └─Limit_55 1.00 root offset:0, count:1",
" │ └─TableReader_61 1.00 root data:Limit_60",
" │ └─Limit_60 1.00 cop[tikv] offset:0, count:1",
Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/stringer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ func (s *testStringerSuite) TestGroupStringer(c *C) {
output[i].SQL = sql
output[i].Result = ToString(group)
})
c.Assert(ToString(group), DeepEquals, output[i].Result)
c.Assert(ToString(group), DeepEquals, output[i].Result, Commentf("case:%v, sql:%s", i, sql))
}
}
4 changes: 2 additions & 2 deletions planner/cascades/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -1048,10 +1048,10 @@
{
"SQL": "select sum(a), (select t1.a from t1 where t1.a = t2.a limit 1), (select t1.b from t1 where t1.b = t2.b limit 1) from t2",
"Plan": [
"Projection_30 1.00 root Column#3, test.t1.a, test.t1.b",
"Projection_30 1.00 root Column#7, test.t1.a, test.t1.b",
"└─Apply_32 1.00 root CARTESIAN left outer join",
" ├─Apply_34(Build) 1.00 root CARTESIAN left outer join",
" │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#8)->Column#3, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
" │ ├─HashAgg_39(Build) 1.00 root funcs:sum(Column#8)->Column#7, funcs:firstrow(Column#9)->test.t2.a, funcs:firstrow(Column#10)->test.t2.b",
" │ │ └─TableReader_40 1.00 root data:HashAgg_41",
" │ │ └─HashAgg_41 1.00 cop[tikv] funcs:sum(test.t2.a)->Column#8, funcs:firstrow(test.t2.a)->Column#9, funcs:firstrow(test.t2.b)->Column#10",
" │ │ └─TableFullScan_38 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
Expand Down
2 changes: 1 addition & 1 deletion planner/cascades/testdata/stringer_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@
"SQL": "select a = (select a from t t2 where t1.b = t2.b order by a limit 1) from t t1",
"Result": [
"Group#0 Schema:[Column#25]",
" Projection_2 input:[Group#1], eq(test.t.a, test.t.a)->Column#25",
" Projection_3 input:[Group#1], eq(test.t.a, test.t.a)->Column#25",
"Group#1 Schema:[test.t.a,test.t.b,test.t.a]",
" Apply_9 input:[Group#2,Group#3], left outer join",
"Group#2 Schema:[test.t.a,test.t.b], UniqueKey:[test.t.a]",
Expand Down
19 changes: 16 additions & 3 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,24 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
if er.aggrMap != nil {
index, ok = er.aggrMap[v]
}
if !ok {
er.err = ErrInvalidGroupFuncUse
if ok {
// index < 0 indicates this is a correlated aggregate belonging to outer query,
// for which a correlated column will be created later, so we append a null constant
// as a temporary result expression.
if index < 0 {
er.ctxStackAppend(expression.NewNull(), types.EmptyName)
} else {
// index >= 0 indicates this is a regular aggregate column
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
}
return inNode, true
}
er.ctxStackAppend(er.schema.Columns[index], er.names[index])
// replace correlated aggregate in sub-query with its corresponding correlated column
if col, ok := er.b.correlatedAggMapper[v]; ok {
er.ctxStackAppend(col, types.EmptyName)
return inNode, true
}
er.err = ErrInvalidGroupFuncUse
return inNode, true
case *ast.ColumnNameExpr:
if index, ok := er.b.colMapper[v]; ok {
Expand Down
80 changes: 80 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2049,6 +2049,86 @@ func (s *testIntegrationSuite) TestOrderByNotInSelectDistinct(c *C) {
tk.MustQuery("select distinct v1 as z from ttest order by v1+z").Check(testkit.Rows("1", "4"))
}

func (s *testIntegrationSuite) TestCorrelatedAggregate(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// #18350
tk.MustExec("DROP TABLE IF EXISTS tab, tab2")
tk.MustExec("CREATE TABLE tab(i INT)")
tk.MustExec("CREATE TABLE tab2(j INT)")
tk.MustExec("insert into tab values(1),(2),(3)")
tk.MustExec("insert into tab2 values(1),(2),(3),(15)")
tk.MustQuery(`SELECT m.i,
(SELECT COUNT(n.j)
FROM tab2 WHERE j=15) AS o
FROM tab m, tab2 n GROUP BY 1 order by m.i`).Check(testkit.Rows("1 4", "2 4", "3 4"))
tk.MustQuery(`SELECT
(SELECT COUNT(n.j)
FROM tab2 WHERE j=15) AS o
FROM tab m, tab2 n order by m.i`).Check(testkit.Rows("12"))

// #17748
tk.MustExec("drop table if exists t1, t2")
tk.MustExec("create table t1 (a int, b int)")
tk.MustExec("create table t2 (m int, n int)")
tk.MustExec("insert into t1 values (2,2), (2,2), (3,3), (3,3), (3,3), (4,4)")
tk.MustExec("insert into t2 values (1,11), (2,22), (3,32), (4,44), (4,44)")
tk.MustExec("set @@sql_mode='TRADITIONAL'")

tk.MustQuery(`select count(*) c, a,
( select group_concat(count(a)) from t2 where m = a )
from t1 group by a order by a`).
Check(testkit.Rows("2 2 2", "3 3 3", "1 4 1,1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1,1),(2,1),(2,2),(3,1),(3,2),(3,3)")

// Sub-queries in SELECT fields
// from SELECT fields
tk.MustQuery("select (select count(a)) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select (select (select count(a)))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select (select count(n.a)) from t m order by count(m.b)) from t n").Check(testkit.Rows("6"))
// from WHERE
tk.MustQuery("select (select count(n.a) from t where count(n.a)=3) from t n").Check(testkit.Rows("<nil>"))
tk.MustQuery("select (select count(a) from t where count(distinct n.a)=3) from t n").Check(testkit.Rows("6"))
// from HAVING
tk.MustQuery("select (select count(n.a) from t having count(n.a)=6 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=3 limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(distinct n.a) from t having count(distinct n.b)=6 limit 1) from t n").Check(testkit.Rows("<nil>"))
// from ORDER BY
tk.MustQuery("select (select count(n.a) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(distinct n.b) from t order by count(n.b) limit 1) from t n").Check(testkit.Rows("3"))
// from TableRefsClause
tk.MustQuery("select (select cnt from (select count(a) cnt) s) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(cnt) from (select count(a) cnt) s) from t").Check(testkit.Rows("1"))
// from sub-query inside aggregate
tk.MustQuery("select (select sum((select count(a)))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum((select count(a))+sum(a))) from t").Check(testkit.Rows("20"))
// from GROUP BY
tk.MustQuery("select (select count(a) from t group by count(n.a)) from t n").Check(testkit.Rows("6"))
tk.MustQuery("select (select count(distinct a) from t group by count(n.a)) from t n").Check(testkit.Rows("3"))

// Sub-queries in HAVING
tk.MustQuery("select sum(a) from t having (select count(a)) = 0").Check(testkit.Rows())
tk.MustQuery("select sum(a) from t having (select count(a)) > 0").Check(testkit.Rows("14"))

// Sub-queries in ORDER BY
tk.MustQuery("select count(a) from t group by b order by (select count(a))").Check(testkit.Rows("1", "2", "3"))
tk.MustQuery("select count(a) from t group by b order by (select -count(a))").Check(testkit.Rows("3", "2", "1"))

// Nested aggregate (correlated aggregate inside aggregate)
tk.MustQuery("select (select sum(count(a))) from t").Check(testkit.Rows("6"))
tk.MustQuery("select (select sum(sum(a))) from t").Check(testkit.Rows("14"))

// Combining aggregates
tk.MustQuery("select count(a), (select count(a)) from t").Check(testkit.Rows("6 6"))
tk.MustQuery("select sum(distinct b), count(a), (select count(a)), (select cnt from (select sum(distinct b) as cnt) n) from t").
Check(testkit.Rows("6 6 6 6"))
}

func (s *testIntegrationSuite) TestCorrelatedColumnAggFuncPushDown(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test;")
Expand Down
Loading