
[GLUTEN-1392][CH] Support new ExpandRel (#1432)
What changes were proposed in this pull request?
Support the new ExpandRel introduced by #1361.

(Fixes: #1392)

How was this patch tested?
Unit tests.
exmy authored Apr 26, 2023
1 parent 1bb242a commit bd43690
Showing 12 changed files with 232 additions and 133 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -16,6 +16,10 @@
# vscode config
.vscode

+# vscode scala
+.bloop
+.metals
+
# Mobile Tools for Java (J2ME)
.mtj.tmp/

@@ -135,6 +135,8 @@ object CHBackendSettings extends BackendSettings with Logging {

  override def supportExpandExec(): Boolean = true

+  override def supportNewExpandContract(): Boolean = true
+
  override def excludeScanExecFromCollapsedStage(): Boolean =
    SQLConf.get
      .getConfString(GLUTEN_CLICKHOUSE_SEP_SCAN_RDD, GLUTEN_CLICKHOUSE_SEP_SCAN_RDD_DEFAULT)
@@ -666,6 +666,52 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite
    compareResultsAgainstVanillaSpark(sql, true, { _ => })
  }

+  test("expand with nullable type not match") {
+    val sql =
+      """
+        |select a, n_regionkey, n_nationkey from
+        |(select nvl(n_name, "aaaa") as a, n_regionkey, n_nationkey from nation)
+        |group by n_regionkey, n_nationkey
+        |grouping sets((a, n_regionkey, n_nationkey),(a, n_regionkey), (a))
+        |order by a, n_regionkey, n_nationkey
+        |""".stripMargin
+    runQueryAndCompare(sql)(checkOperatorMatch[ExpandExecTransformer])
+  }
+
+  test("expand col result") {
+    val sql =
+      """
+        |select n_regionkey, n_nationkey, count(1) as cnt from nation
+        |group by n_regionkey, n_nationkey with rollup
+        |order by n_regionkey, n_nationkey, cnt
+        |""".stripMargin
+    runQueryAndCompare(sql)(checkOperatorMatch[ExpandExecTransformer])
+  }
+
+  test("expand with not nullable") {
+    val sql =
+      """
+        |select a,b, sum(c) from
+        |(select nvl(n_nationkey, 0) as c, nvl(n_name, '') as b, nvl(n_nationkey, 0) as a from nation)
+        |group by a,b with rollup
+        |""".stripMargin
+    runQueryAndCompare(sql)(checkOperatorMatch[ExpandExecTransformer])
+  }
+
+  test("expand with function expr") {
+    val sql =
+      """
+        |select
+        | n_name,
+        | count(distinct n_regionkey) as col1,
+        | count(distinct concat(n_regionkey, n_nationkey)) as col2
+        |from nation
+        |group by n_name
+        |order by n_name, col1, col2
+        |""".stripMargin
+    runQueryAndCompare(sql)(checkOperatorMatch[ExpandExecTransformer])
+  }
+
  test("test 'position/locate'") {
    runQueryAndCompare(
      """
45 changes: 13 additions & 32 deletions cpp-ch/local-engine/Operator/ExpandStep.cpp
@@ -32,53 +32,34 @@ static DB::ITransformingStep::Traits getTraits()

ExpandStep::ExpandStep(
    const DB::DataStream & input_stream_,
-    const std::vector<size_t> & aggregating_expressions_columns_,
-    const std::vector<std::set<size_t>> & grouping_sets_,
-    const std::string & grouping_id_name_)
+    const ExpandField & project_set_exprs_)
    : DB::ITransformingStep(
        input_stream_,
-        buildOutputHeader(input_stream_.header, aggregating_expressions_columns_, grouping_id_name_),
+        buildOutputHeader(input_stream_.header, project_set_exprs_),
        getTraits())
-    , aggregating_expressions_columns(aggregating_expressions_columns_)
-    , grouping_sets(grouping_sets_)
-    , grouping_id_name(grouping_id_name_)
+    , project_set_exprs(project_set_exprs_)
{
    header = input_stream_.header;
    output_header = getOutputStream().header;
}

DB::Block ExpandStep::buildOutputHeader(
    const DB::Block & input_header,
-    const std::vector<size_t> & aggregating_expressions_columns_,
-    const std::string & grouping_id_name_)
+    const ExpandField & project_set_exprs_)
{
    DB::ColumnsWithTypeAndName cols;
-    std::set<size_t> agg_cols;
+    const auto & types = project_set_exprs_.getTypes();
+    const auto & names = project_set_exprs_.getNames();

-    for (size_t i = 0; i < input_header.columns(); ++i)
+    for (size_t i = 0; i < project_set_exprs_.getExpandCols(); ++i)
    {
-        const auto & old_col = input_header.getByPosition(i);
-        if (i < aggregating_expressions_columns_.size())
-        {
-            // do nothing with the aggregating columns.
-            cols.push_back(old_col);
-            continue;
-        }
-        if (old_col.type->isNullable())
-            cols.push_back(old_col);
+        String col_name;
+        if (!names[i].empty())
+            col_name = names[i];
        else
-        {
-            auto null_map = DB::ColumnUInt8::create(0, 0);
-            auto null_col = DB::ColumnNullable::create(old_col.column, std::move(null_map));
-            auto null_type = std::make_shared<DB::DataTypeNullable>(old_col.type);
-            cols.push_back(DB::ColumnWithTypeAndName(null_col, null_type, old_col.name));
-        }
+            col_name = "expand_" + std::to_string(i);
+        cols.push_back(DB::ColumnWithTypeAndName(types[i], col_name));
    }

-    // add group id column
-    auto grouping_id_col = DB::ColumnInt64::create(0, 0);
-    auto grouping_id_type = std::make_shared<DB::DataTypeInt64>();
-    cols.emplace_back(DB::ColumnWithTypeAndName(std::move(grouping_id_col), grouping_id_type, grouping_id_name_));
    return DB::Block(cols);
}

@@ -89,7 +70,7 @@ void ExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB
    DB::Processors new_processors;
    for (auto & output : outputs)
    {
-        auto expand_op = std::make_shared<ExpandTransform>(header, output_header, aggregating_expressions_columns, grouping_sets);
+        auto expand_op = std::make_shared<ExpandTransform>(header, output_header, project_set_exprs);
        new_processors.push_back(expand_op);
        DB::connect(*output, expand_op->getInputs().front());
    }
12 changes: 4 additions & 8 deletions cpp-ch/local-engine/Operator/ExpandStep.h
@@ -3,6 +3,7 @@
#include <Core/Block.h>
#include <Processors/QueryPlan/IQueryPlanStep.h>
#include <Processors/QueryPlan/ITransformingStep.h>
+#include <Parser/ExpandField.h>

namespace local_engine
{
@@ -12,27 +13,22 @@ class ExpandStep : public DB::ITransformingStep
    // The input stream should only contain grouping columns.
    explicit ExpandStep(
        const DB::DataStream & input_stream_,
-        const std::vector<size_t> & aggregating_expressions_columns_,
-        const std::vector<std::set<size_t>> & grouping_sets_,
-        const std::string & grouping_id_name_);
+        const ExpandField & project_set_exprs_);
    ~ExpandStep() override = default;

    String getName() const override { return "ExpandStep"; }

    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
private:
-    std::vector<size_t> aggregating_expressions_columns;
-    std::vector<std::set<size_t>> grouping_sets;
-    std::string grouping_id_name;
+    ExpandField project_set_exprs;
    DB::Block header;
    DB::Block output_header;

    void updateOutputStream() override;

    static DB::Block buildOutputHeader(
        const DB::Block & header,
-        const std::vector<size_t> & aggregating_expressions_columns_,
-        const std::string & grouping_id_name_);
+        const ExpandField & project_set_exprs_);
};
}
72 changes: 43 additions & 29 deletions cpp-ch/local-engine/Operator/ExpandTransform.cpp
@@ -4,8 +4,10 @@
#include <Columns/ColumnsNumber.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypesNumber.h>
+#include <DataTypes/DataTypeNullable.h>
#include <Processors/IProcessor.h>

+#include "Common/Exception.h"
#include <Common/logger_useful.h>
#include <Poco/Logger.h>

@@ -14,11 +16,9 @@ namespace local_engine
ExpandTransform::ExpandTransform(
    const DB::Block & input_,
    const DB::Block & output_,
-    const std::vector<size_t> & aggregating_expressions_columns_,
-    const std::vector<std::set<size_t>> & grouping_sets_)
+    const ExpandField & project_set_exprs_)
    : DB::IProcessor({input_}, {output_})
-    , aggregating_expressions_columns(aggregating_expressions_columns_)
-    , grouping_sets(grouping_sets_)
+    , project_set_exprs(project_set_exprs_)
{}

ExpandTransform::Status ExpandTransform::prepare()
@@ -68,43 +68,57 @@ ExpandTransform::Status ExpandTransform::prepare()
void ExpandTransform::work()
{
    assert(expanded_chunks.empty());
-    size_t agg_cols_size = aggregating_expressions_columns.size();
-    for (int set_id = 0; static_cast<size_t>(set_id) < grouping_sets.size(); ++set_id)
-    {
-        const auto & sets = grouping_sets[set_id];
-        DB::Columns cols;
-        const auto & original_cols = input_chunk.getColumns();
-        for (size_t i = 0; i < original_cols.size(); ++i)
-        {
-            const auto & original_col = original_cols[i];
-            size_t rows = original_col->size();
-            if (i < agg_cols_size)
-            {
-                cols.push_back(original_col);
-                continue;
-            }
-            // the output columns should all be nullable.
-            if (!sets.contains(i))
-            {
-                auto null_map = DB::ColumnUInt8::create(rows, 1);
-                auto col = DB::ColumnNullable::create(original_col, std::move(null_map));
-                cols.push_back(std::move(col));
-            }
-            else
-            {
-                if (original_col->isNullable())
-                    cols.push_back(original_col);
-                else
-                {
-                    auto null_map = DB::ColumnUInt8::create(rows, 0);
-                    auto col = DB::ColumnNullable::create(original_col, std::move(null_map));
-                    cols.push_back(std::move(col));
-                }
-            }
-        }
-        auto id_col = DB::DataTypeInt64().createColumnConst(input_chunk.getNumRows(), set_id);
-        cols.push_back(std::move(id_col));
-        expanded_chunks.push_back(DB::Chunk(cols, input_chunk.getNumRows()));
-    }
+    const auto & original_cols = input_chunk.getColumns();
+    size_t rows = input_chunk.getNumRows();
+
+    for (size_t i = 0; i < project_set_exprs.getExpandRows(); ++i)
+    {
+        DB::Columns cols;
+        for (size_t j = 0; j < project_set_exprs.getExpandCols(); ++j)
+        {
+            const auto & type = project_set_exprs.getTypes()[j];
+            const auto & kind = project_set_exprs.getKinds()[i][j];
+            const auto & field = project_set_exprs.getFields()[i][j];
+
+            if (kind == EXPAND_FIELD_KIND_SELECTION)
+            {
+                const auto & original_col = original_cols[field.get<Int32>()];
+                if (type->isNullable() == original_col->isNullable())
+                {
+                    cols.push_back(original_col);
+                }
+                else if (type->isNullable() && !original_col->isNullable())
+                {
+                    auto null_map = DB::ColumnUInt8::create(rows, 0);
+                    auto col = DB::ColumnNullable::create(original_col, std::move(null_map));
+                    cols.push_back(std::move(col));
+                }
+                else
+                {
+                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR,
+                        "Mismatched nullability: column {} is nullable, but type {} is not nullable",
+                        original_col->getName(), type->getName());
+                }
+            }
+            else
+            {
+                if (field.isNull())
+                {
+                    // Add a null column.
+                    auto null_map = DB::ColumnUInt8::create(rows, 1);
+                    auto nested_type = DB::removeNullable(type);
+                    auto col = DB::ColumnNullable::create(nested_type->createColumn()->cloneResized(rows), std::move(null_map));
+                    cols.push_back(std::move(col));
+                }
+                else
+                {
+                    // Add a constant column: gid, gpos, etc.
+                    auto col = type->createColumnConst(rows, field);
+                    cols.push_back(std::move(col->convertToFullColumnIfConst()));
+                }
+            }
+        }
+        expanded_chunks.push_back(DB::Chunk(cols, rows));
+    }
    has_output = true;
    has_input = false;
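The rewritten ExpandTransform::work() emits one output chunk per projection set: SELECTION cells copy an input column (wrapping it in a nullable column when the declared type requires it), and LITERAL cells are materialized as all-null columns or as constants such as the grouping id. To make that loop concrete, here is a minimal, dependency-free model of the same expansion logic (my own illustration, not the local-engine API): std::optional<int64_t> stands in for nullable columns, and the grouping-id literals 0/1/3 follow Spark's rollup encoding.

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

enum Kind { SELECTION, LITERAL };

// SELECTION: value holds the input column index; LITERAL: value is the constant (or nullopt for NULL).
struct Cell { Kind kind; std::optional<int64_t> value; };

using Row = std::vector<std::optional<int64_t>>;

int main()
{
    // One input chunk with two rows of (a, b).
    std::vector<Row> input = {{10, 20}, {30, 40}};

    // Projection sets for GROUP BY a, b WITH ROLLUP; the last column is the grouping id.
    std::vector<std::vector<Cell>> sets = {
        {{SELECTION, 0}, {SELECTION, 1}, {LITERAL, 0}},                   // (a, b) -> gid 0
        {{SELECTION, 0}, {LITERAL, std::nullopt}, {LITERAL, 1}},          // (a)    -> gid 1
        {{LITERAL, std::nullopt}, {LITERAL, std::nullopt}, {LITERAL, 3}}, // ()     -> gid 3
    };

    // As in ExpandTransform::work(): one output "chunk" per projection set.
    for (const auto & set : sets)
    {
        for (const auto & in : input)
        {
            Row out;
            for (const auto & cell : set)
                out.push_back(cell.kind == SELECTION ? in[static_cast<size_t>(*cell.value)] : cell.value);
            for (const auto & v : out)
                std::cout << (v ? std::to_string(*v) : "NULL") << ' ';
            std::cout << '\n';
        }
        std::cout << "--\n";
    }
}

For this sample input the model prints three blocks: (10 20 0)/(30 40 0), (10 NULL 1)/(30 NULL 1), and (NULL NULL 3)/(NULL NULL 3), which mirrors the chunks the transform would append to expanded_chunks.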
7 changes: 3 additions & 4 deletions cpp-ch/local-engine/Operator/ExpandTransorm.h
@@ -5,6 +5,7 @@
#include <Processors/Chunk.h>
#include <Processors/IProcessor.h>
#include <Processors/Port.h>
+#include <Parser/ExpandField.h>
namespace local_engine
{
// For handling substrait expand node.
@@ -21,16 +22,14 @@ class ExpandTransform : public DB::IProcessor
    ExpandTransform(
        const DB::Block & input_,
        const DB::Block & output_,
-        const std::vector<size_t> & aggregating_expressions_columns_,
-        const std::vector<std::set<size_t>> & grouping_sets_);
+        const ExpandField & project_set_exprs_);

    Status prepare() override;
    void work() override;

    DB::String getName() const override { return "ExpandTransform"; }
private:
-    std::vector<size_t> aggregating_expressions_columns;
-    std::vector<std::set<size_t>> grouping_sets;
+    ExpandField project_set_exprs;
    bool has_input = false;
    bool has_output = false;

42 changes: 42 additions & 0 deletions cpp-ch/local-engine/Parser/ExpandField.h
@@ -0,0 +1,42 @@
#pragma once

#include <Core/Field.h>
#include <DataTypes/IDataType.h>

namespace local_engine
{

enum ExpandFieldKind
{
    EXPAND_FIELD_KIND_SELECTION,
    EXPAND_FIELD_KIND_LITERAL,
};

class ExpandField
{
public:
    ExpandField() = default;
    ExpandField(
        const std::vector<std::string> & names_,
        const std::vector<DB::DataTypePtr> & types_,
        const std::vector<std::vector<ExpandFieldKind>> & kinds_,
        const std::vector<std::vector<DB::Field>> & fields_):
        names(names_), types(types_), kinds(kinds_), fields(fields_)
    {}

    const std::vector<std::string> & getNames() const { return names; }
    const std::vector<DB::DataTypePtr> & getTypes() const { return types; }
    const std::vector<std::vector<ExpandFieldKind>> & getKinds() const { return kinds; }
    const std::vector<std::vector<DB::Field>> & getFields() const { return fields; }

    size_t getExpandRows() const { return kinds.size(); }
    size_t getExpandCols() const { return types.size(); }

private:
    std::vector<std::string> names;
    std::vector<DB::DataTypePtr> types;
    std::vector<std::vector<ExpandFieldKind>> kinds;
    std::vector<std::vector<DB::Field>> fields;
};

}
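As a usage illustration (not part of this commit): a hedged sketch of how an ExpandField for `GROUP BY a, b WITH ROLLUP` over two Int64 input columns might be assembled, assuming the local-engine/ClickHouse build environment and Spark's rollup grouping-id encoding (0, 1, 3). SELECTION entries carry the input column index in their Field; LITERAL entries carry the constant (or Null), matching how ExpandTransform::work() interprets them above. The function name and concrete values are illustrative only.

#include <memory>
#include <string>
#include <vector>

#include <Parser/ExpandField.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeNullable.h>

local_engine::ExpandField makeRollupExpandField()
{
    using namespace local_engine;
    auto i64 = std::make_shared<DB::DataTypeInt64>();
    auto nullable_i64 = std::make_shared<DB::DataTypeNullable>(i64);

    // One output schema shared by every projection set: a, b, gid.
    std::vector<std::string> names = {"a", "b", "gid"};
    std::vector<DB::DataTypePtr> types = {nullable_i64, nullable_i64, i64};

    // One row of kinds/fields per projection set: (a, b), (a), ().
    std::vector<std::vector<ExpandFieldKind>> kinds = {
        {EXPAND_FIELD_KIND_SELECTION, EXPAND_FIELD_KIND_SELECTION, EXPAND_FIELD_KIND_LITERAL},
        {EXPAND_FIELD_KIND_SELECTION, EXPAND_FIELD_KIND_LITERAL, EXPAND_FIELD_KIND_LITERAL},
        {EXPAND_FIELD_KIND_LITERAL, EXPAND_FIELD_KIND_LITERAL, EXPAND_FIELD_KIND_LITERAL},
    };
    // SELECTION entries hold the input column index; LITERAL entries hold the
    // constant (or a Null Field). The gid values 0/1/3 follow Spark's rollup encoding.
    std::vector<std::vector<DB::Field>> fields = {
        {DB::Field(Int32(0)), DB::Field(Int32(1)), DB::Field(Int64(0))},
        {DB::Field(Int32(0)), DB::Field(), DB::Field(Int64(1))},
        {DB::Field(), DB::Field(), DB::Field(Int64(3))},
    };
    return ExpandField(names, types, kinds, fields);
}

Fed to ExpandStep::buildOutputHeader above, such a field yields a three-column output header (a, b, gid); any empty name would fall back to the generated expand_0, expand_1, ... form.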