Skip to content

Commit

Permalink
planner: fix wrong result when pushing Agg down through Union in MPP …
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Oct 13, 2023
1 parent 715048e commit e07cf40
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 37 deletions.
2 changes: 1 addition & 1 deletion executor/tiflashtest/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ go_test(
],
flaky = True,
race = "on",
shard_count = 37,
shard_count = 38,
deps = [
"//config",
"//domain",
Expand Down
29 changes: 29 additions & 0 deletions executor/tiflashtest/tiflash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,35 @@ func TestAggPushDownCountStar(t *testing.T) {
tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5"))
}

// TestAggPushDownUnionAndMPP verifies that count results stay correct when
// MPP is enforced and aggregate push-down is enabled, both for an aggregate
// over a UNION ALL subquery and for an aggregate over a join.
func TestAggPushDownUnionAndMPP(t *testing.T) {
	store := testkit.CreateMockStore(t, withMockTiFlash(2))
	tk := testkit.NewTestKit(t, store)

	// The statements are executed strictly in the listed order.
	setupStmts := []string{
		"use test",
		"create table t (a int, b int)",
		"alter table t set tiflash replica 1",
		// Five identical rows in t.
		"insert into t values (1, 1);",
		"insert into t values (1, 1);",
		"insert into t values (1, 1);",
		"insert into t values (1, 1);",
		"insert into t values (1, 1);",
		"set @@tidb_allow_mpp=1;",
		"set @@tidb_enforce_mpp=1;",
		"set @@tidb_opt_agg_push_down=1",

		"create table c(c_id int)",
		"create table o(o_id int, c_id int)",
		"insert into c values(1),(1),(1),(1)",
		"insert into o values(1,1),(1,1),(1,2)",
		"alter table c set tiflash replica 1",
		"alter table o set tiflash replica 1",
	}
	for _, stmt := range setupStmts {
		tk.MustExec(stmt)
	}

	// Both UNION ALL branches read the same five rows of t, so the single
	// group a=1 must count 10 rows.
	unionAggSQL := "select a, count(1) from (select a, b from t union all select a, b from t) s group by a order by a"
	tk.MustQuery(unionAggSQL).Check(testkit.Rows("1 10"))

	// Four rows in c (c_id=1) join three rows in o (o_id=1): 4 * 3 = 12.
	joinAggSQL := "select o.o_id, count(*) from c, o where c.c_id=o.o_id group by o.o_id"
	tk.MustQuery(joinAggSQL).Check(testkit.Rows("1 12"))
}

func TestGroupStreamAggOnTiFlash(t *testing.T) {
store := testkit.CreateMockStore(t, withMockTiFlash(2))
tk := testkit.NewTestKit(t, store)
Expand Down
9 changes: 8 additions & 1 deletion planner/core/casetest/enforce_mpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,20 @@ func TestMPP2PhaseAggPushDown(t *testing.T) {
tk.MustExec("create table c(c_id bigint)")
tk.MustExec("create table o(o_id bigint, c_id bigint not null)")

tk.MustExec("create table t (a int, b int)")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")
tk.MustExec("insert into t values (1, 1);")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" || tblInfo.Name.L == "t" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
Expand Down
3 changes: 2 additions & 1 deletion planner/core/casetest/testdata/enforce_mpp_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
"EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
"EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column"
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
"EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10"
]
},
{
Expand Down
99 changes: 65 additions & 34 deletions planner/core/casetest/testdata/enforce_mpp_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -658,48 +658,79 @@
{
"SQL": "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"Plan": [
"TableReader_78 8000.00 root MppVersion: 1, data:ExchangeSender_77",
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
"TableReader_84 8000.00 root MppVersion: 1, data:ExchangeSender_83",
"└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6",
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.o_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
" └─Projection_79 8000.00 mpp[tiflash] Column#6, test.o.o_id",
" └─HashAgg_80 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:sum(Column#25)->Column#6, funcs:firstrow(Column#26)->test.o.o_id",
" └─ExchangeReceiver_82 8000.00 mpp[tiflash] ",
" └─ExchangeSender_81 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary]",
" └─HashAgg_76 8000.00 mpp[tiflash] group by:Column#29, funcs:sum(Column#27)->Column#25, funcs:firstrow(Column#28)->Column#26",
" └─Projection_85 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#27, Column#8, test.o.o_id",
" └─HashJoin_78 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
],
"Warn": null
},
{
"SQL": "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
"Plan": [
"TableReader_78 8000.00 root MppVersion: 1, data:ExchangeSender_77",
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
"TableReader_84 8000.00 root MppVersion: 1, data:ExchangeSender_83",
"└─ExchangeSender_83 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6",
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.c_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
" └─Projection_79 8000.00 mpp[tiflash] Column#6, test.o.c_id",
" └─HashAgg_80 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#21)->Column#6, funcs:firstrow(Column#22)->test.o.c_id",
" └─ExchangeReceiver_82 8000.00 mpp[tiflash] ",
" └─ExchangeSender_81 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" └─HashAgg_76 8000.00 mpp[tiflash] group by:Column#25, funcs:sum(Column#23)->Column#21, funcs:firstrow(Column#24)->Column#22",
" └─Projection_85 9990.00 mpp[tiflash] cast(Column#7, decimal(20,0) BINARY)->Column#23, Column#8, test.o.c_id",
" └─HashJoin_78 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_34(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_33 8000.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST",
" │ └─Projection_29 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
" │ └─HashAgg_30 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:sum(Column#9)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_32 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_31 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.o.c_id, collate: binary]",
" │ └─HashAgg_21 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#9",
" │ └─TableFullScan_28 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c pushed down filter:empty, keep order:false, stats:pseudo"
],
"Warn": null
},
{
"SQL": "EXPLAIN format='brief' select a, count(*) from (select a, b from t union all select a, b from t) t group by a order by a limit 10",
"Plan": [
"Projection 10.00 root Column#7, Column#9",
"└─TopN 10.00 root Column#7, offset:0, count:10",
" └─TableReader 10.00 root MppVersion: 1, data:ExchangeSender",
" └─ExchangeSender 10.00 mpp[tiflash] ExchangeType: PassThrough",
" └─TopN 10.00 mpp[tiflash] Column#7, offset:0, count:10",
" └─Projection 16000.00 mpp[tiflash] Column#9, Column#7",
" └─HashAgg 16000.00 mpp[tiflash] group by:Column#40, funcs:sum(Column#38)->Column#9, funcs:firstrow(Column#39)->Column#7",
" └─Projection 16000.00 mpp[tiflash] cast(Column#10, decimal(20,0) BINARY)->Column#38, Column#11, Column#7",
" └─ExchangeReceiver 16000.00 mpp[tiflash] ",
" └─ExchangeSender 16000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: Column#7, collate: binary]",
" └─Union 16000.00 mpp[tiflash] ",
" ├─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#30)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
" │ └─ExchangeReceiver 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
" │ └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#30",
" │ └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo",
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:sum(Column#33)->Column#10, funcs:firstrow(test.t.a)->Column#11, funcs:firstrow(test.t.a)->Column#7",
" └─ExchangeReceiver 8000.00 mpp[tiflash] ",
" └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: HashPartition, Compression: FAST, Hash Cols: [name: test.t.a, collate: binary]",
" └─HashAgg 8000.00 mpp[tiflash] group by:test.t.a, funcs:count(1)->Column#33",
" └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo"
],
"Warn": null
}
Expand Down
11 changes: 11 additions & 0 deletions planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -3151,6 +3151,16 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
// Is this aggregate a final stage aggregate?
// Final agg can't be split into multi-stage aggregate
hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode
// A final-stage count agg must become a sum on the MPP execution path.
// In the traditional case, TiDB takes up the final agg role and pushes the partial agg to TiKV;
// TiDB can tell the partial mode and does the sum computation rather than counting, but MPP doesn't.
//
// finalAggAdjust walks aggFuncs and replaces, in place, every descriptor that is both
// in FinalMode and named count with a new sum descriptor built over the same arguments.
finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) {
for i, agg := range aggFuncs {
if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount {
// NOTE(review): the error returned by NewAggFuncDesc is discarded here. Presumably building
// a sum over the existing count arguments cannot fail — confirm, because on failure the slot
// would be overwritten with whatever descriptor value accompanies the error.
aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx(), ast.AggFuncSum, agg.Args, false)
}
}
}

if len(la.GroupByItems) > 0 {
partitionCols := la.GetPotentialPartitionKeys()
Expand All @@ -3176,6 +3186,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
agg.SetSchema(la.schema.Clone())
agg.MppRunMode = Mpp1Phase
finalAggAdjust(agg.AggFuncs)
hashAggs = append(hashAggs, agg)
}

Expand Down

0 comments on commit e07cf40

Please sign in to comment.