Skip to content

Commit

Permalink
[gluten-3865]
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Nov 28, 2023
1 parent b707ba6 commit 959febf
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
| create table test_tbl(id bigint, name string) using parquet;
|""".stripMargin
)
val sql1 = "select count(1), sum(id), max(id), min(id) from test_tbl"
val sql1 = "select count(1), sum(id), max(id), min(id), 'abc' as x from test_tbl"
val sql2 =
"select count(1) as cnt, sum(id) as sum, max(id) as max, min(id) as min from test_tbl"
compareResultsAgainstVanillaSpark(sql1, true, { _ => })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "EmptyHashAggregate.h"
#include "DefaultHashAggregateResult.h"
#include <memory>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnsNumber.h>
Expand Down Expand Up @@ -46,13 +46,28 @@ static DB::ITransformingStep::Traits getTraits()
}};
}

/// Always return a block with one row. Don't care what is in it.
class EmptyHashAggregate : public DB::IProcessor
/// A more special case, the aggregate functions is also empty.
/// We add a fake block to downstream.
DB::Block adjustOutputHeader(const DB::Block & original_block)
{
public:
explicit EmptyHashAggregate(const DB::Block & input_) : DB::IProcessor({input_}, {BlockUtil::buildRowCountHeader()}) { }
~EmptyHashAggregate() override = default;
if (original_block)
return original_block;
return BlockUtil::buildRowCountHeader();
}

class DefaultHashAggrgateResultTransform : public DB::IProcessor
{
public:
explicit DefaultHashAggrgateResultTransform(const DB::Block & input_) : DB::IProcessor({input_}, {adjustOutputHeader(input_)}), header(input_) { }
~DefaultHashAggrgateResultTransform() override = default;
void work() override
{
if (has_input)
{
has_input = false;
has_output = true;
}
}
Status prepare() override
{
auto & output = outputs.front();
Expand All @@ -66,59 +81,75 @@ class EmptyHashAggregate : public DB::IProcessor
{
if (output.canPush())
{

output.push(std::move(output_chunk));
has_output = false;
has_outputed = true;
return Status::PortFull;
}
return Status::PortFull;
}

if (has_input)
return Status::Ready;

if (input.isFinished())
{
if (has_outputed)
{
output.finish();
return Status::Finished;
}
if (!has_output)
DB::Columns result_cols;
if (header)
{
for (const auto & col : header.getColumnsWithTypeAndName())
{
auto result_col = col.type->createColumnConst(1, col.type->getDefault());
result_cols.emplace_back(result_col);
}
}
else
{
output_chunk = BlockUtil::buildRowCountChunk(1);
has_output = true;
auto cnt_chunk = BlockUtil::buildRowCountChunk(1);
result_cols = cnt_chunk.detachColumns();
}
has_input = true;
output_chunk = DB::Chunk(result_cols, 1);
return Status::Ready;
}

input.setNeeded();
if (input.hasData())
{
(void)input.pullData(true);
output_chunk = input.pull(true);
has_input = true;
return Status::Ready;
}
return Status::NeedData;
}
void work() override { }

DB::String getName() const override { return "EmptyHashAggregate"; }
}

DB::String getName() const override { return "DefaultHashAggrgateResultTransform"; }
private:
bool has_outputed = false;
DB::Block header;
bool has_input = false;
bool has_output = false;
bool has_outputed = false;
DB::Chunk output_chunk;
};

EmptyHashAggregateStep::EmptyHashAggregateStep(const DB::DataStream & input_stream_)
: DB::ITransformingStep(input_stream_, BlockUtil::buildRowCountHeader(), getTraits())
DefaultHashAggregateResultStep::DefaultHashAggregateResultStep(const DB::DataStream & input_stream_)
: DB::ITransformingStep(input_stream_, input_stream_.header, getTraits())
{
}

void EmptyHashAggregateStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
void DefaultHashAggregateResultStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
{
auto build_transform = [&](DB::OutputPortRawPtrs outputs)
{
DB::Processors new_processors;
for (auto & output : outputs)
{
auto op = std::make_shared<EmptyHashAggregate>(output->getHeader());
auto op = std::make_shared<DefaultHashAggrgateResultTransform>(output->getHeader());
new_processors.push_back(op);
DB::connect(*output, op->getInputs().front());
}
Expand All @@ -127,14 +158,14 @@ void EmptyHashAggregateStep::transformPipeline(DB::QueryPipelineBuilder & pipeli
pipeline.transform(build_transform);
}

void EmptyHashAggregateStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
void DefaultHashAggregateResultStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
{
if (!processors.empty())
DB::IQueryPlanStep::describePipeline(processors, settings);
}

void EmptyHashAggregateStep::updateOutputStream()
void DefaultHashAggregateResultStep::updateOutputStream()
{
createOutputStream(input_streams.front(), BlockUtil::buildRowCountHeader(), getDataStreamTraits());
createOutputStream(input_streams.front(), input_streams.front().header, getDataStreamTraits());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
#include <Processors/QueryPlan/ITransformingStep.h>
namespace local_engine
{
class EmptyHashAggregateStep : public DB::ITransformingStep

/// Special case: goruping keys is empty, and there is no input from updstream, but still need to return one default row.
class DefaultHashAggregateResultStep : public DB::ITransformingStep
{
public:
explicit EmptyHashAggregateStep(const DB::DataStream & input_stream_);
~EmptyHashAggregateStep() override = default;
explicit DefaultHashAggregateResultStep(const DB::DataStream & input_stream_);
~DefaultHashAggregateResultStep() override = default;

String getName() const override { return "EmptyHashAggregateStep"; }
String getName() const override { return "DefaultHashAggregateResultStep"; }

void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
Expand Down
23 changes: 11 additions & 12 deletions cpp-ch/local-engine/Parser/AggregateRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <Common/StringUtils/StringUtils.h>
#include "DataTypes/IDataType.h"

#include <Operator/EmptyHashAggregate.h>
#include <Operator/DefaultHashAggregateResult.h>

namespace DB
{
Expand All @@ -52,17 +52,6 @@ AggregateRelParser::AggregateRelParser(SerializedPlanParser * plan_paser_) : Rel
DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const substrait::Rel & rel, std::list<const substrait::Rel *> &)
{
setup(std::move(query_plan), rel);
LOG_TRACE(logger, "original header is: {}", plan->getCurrentDataStream().header.dumpStructure());
if (rel.aggregate().measures().empty()
&& (rel.aggregate().groupings().empty() || rel.aggregate().groupings()[0].grouping_expressions().empty()))
{
LOG_TRACE(&Poco::Logger::get("AggregateRelParser"), "Empty aggregate step");
auto empty_agg = std::make_unique<EmptyHashAggregateStep>(plan->getCurrentDataStream());
empty_agg->setStepDescription("Empty aggregate");
steps.push_back(empty_agg.get());
plan->addStep(std::move(empty_agg));
return std::move(plan);
}
addPreProjection();
LOG_TRACE(logger, "header after pre-projection is: {}", plan->getCurrentDataStream().header.dumpStructure());
if (has_final_stage)
Expand All @@ -77,6 +66,16 @@ DB::QueryPlanPtr AggregateRelParser::parse(DB::QueryPlanPtr query_plan, const su
addAggregatingStep();
LOG_TRACE(logger, "header after aggregating is: {}", plan->getCurrentDataStream().header.dumpStructure());
}

/// If the groupings is empty, we still need to return one row with default values even if the input is empty.
if (rel.aggregate().groupings().empty() || rel.aggregate().groupings()[0].grouping_expressions().empty())
{
LOG_TRACE(&Poco::Logger::get("AggregateRelParser"), "default aggregate step");
auto default_agg = std::make_unique<DefaultHashAggregateResultStep>(plan->getCurrentDataStream());
default_agg->setStepDescription("Default aggregate");
steps.push_back(default_agg.get());
plan->addStep(std::move(default_agg));
}
return std::move(plan);
}

Expand Down
37 changes: 0 additions & 37 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2169,10 +2169,6 @@ bool LocalExecutor::hasNext()
auto empty_block = header.cloneEmpty();
setCurrentBlock(empty_block);
has_next = executor->pull(currentBlock());
if (!has_next)
{
has_next = checkAndSetDefaultBlock(columns, has_next);
}
produce();
}
else
Expand Down Expand Up @@ -2236,39 +2232,6 @@ LocalExecutor::LocalExecutor(QueryContext & _query_context, ContextPtr context_)
{
}

bool LocalExecutor::checkAndSetDefaultBlock(size_t current_block_columns, bool has_next_blocks)
{
if (current_block_columns > 0 || has_next_blocks)
{
return has_next_blocks;
}
bool should_set_default_value = false;
for (auto p : query_pipeline.getProcessors())
{
if (p->getName() == "MergingAggregatedTransform")
{
DB::MergingAggregatedStep * agg_step = static_cast<DB::MergingAggregatedStep *>(p->getQueryPlanStep());
auto query_params = agg_step->getParams();
should_set_default_value = query_params.keys_size == 0;
}
}
if (!should_set_default_value)
return false;
auto cols = currentBlock().getColumnsWithTypeAndName();
for (size_t i = 0; i < cols.size(); i++)
{
const DB::ColumnWithTypeAndName col = cols[i];
String col_name = col.name;
DataTypePtr col_type = col.type;
const DB::ColumnPtr & default_col_ptr = col_type->createColumnConst(1, col_type->getDefault());
const DB::ColumnWithTypeAndName default_col(default_col_ptr, col_type, col_name);
currentBlock().setColumn(i, default_col);
}
if (cols.size() > 0)
return true;
return false;
}

NonNullableColumnsResolver::NonNullableColumnsResolver(
const DB::Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_)
: header(header_), parser(parser_), cond_rel(cond_rel_)
Expand Down
1 change: 0 additions & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,6 @@ class LocalExecutor : public BlockIterator
private:
QueryContext query_context;
std::unique_ptr<SparkRowInfo> writeBlockToSparkRow(DB::Block & block);
bool checkAndSetDefaultBlock(size_t current_block_columns, bool has_next_blocks);
QueryPipeline query_pipeline;
std::unique_ptr<PullingPipelineExecutor> executor;
Block header;
Expand Down

0 comments on commit 959febf

Please sign in to comment.