[PIR+CINN]Polish CompilationCache for Parsing GroupInputDimExprs in case of BroadcastTree (#63750)
Aurelius84 authored Apr 23, 2024
1 parent 20481c7 commit 66818fb
Showing 6 changed files with 36 additions and 29 deletions.
@@ -61,26 +61,8 @@ pir::Operation* ProcessDyShapeGroup(
} else { // no condition block
// compile group to jit_kernel_op
std::vector<pir::Type> output_types;
- const auto& group_output_values = group->output_values();
- for (size_t i = 0; i < group_output_values.size(); ++i) {
- auto base_type =
- group_output_values[i].type().dyn_cast<::pir::DenseTensorType>();
- auto dim_info = base_type.dims();
- if (shape_analysis.HasShapeOrDataForValue(group_output_values[i])) {
- auto shape = group->GetShapeOrDataExprs(group_output_values[i]).shape();
- for (size_t k = 0; k < shape.size(); ++k) {
- if (shape[k].isa<int64_t>()) {
- dim_info[k] = shape[k].Get<int64_t>();
- }
- }
- }
- auto new_type = ::pir::DenseTensorType::get(pir::IrContext::Instance(),
- base_type.dtype(),
- dim_info,
- base_type.data_layout(),
- base_type.lod(),
- base_type.offset());
- output_types.push_back(new_type);
+ for (const auto& value : group->output_values()) {
+ output_types.push_back(value.type());
}
auto jit_kernel_op = rewriter.Build<cinn::dialect::JitKernelOp>(
group_inputs, GetJitKernelAttr(group), output_types);
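With this change, the JitKernelOp output types are taken directly from each output value's type instead of rebuilding a DenseTensorType whose dimensions were patched in from the symbolic shape analysis. The deleted loop boiled down to the dimension substitution below, shown as a standalone sketch with plain STL types rather than Paddle's pir and symbol classes (the names here are illustrative only):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in: one entry per output dimension; the optional holds a
// value only when the shape analysis resolved that dimension to a constant.
using ResolvedDims = std::vector<std::optional<int64_t>>;

// Overwrite dynamic dimensions with the constants that are known, roughly what
// the removed block did before constructing a new DenseTensorType.
std::vector<int64_t> SubstituteKnownDims(std::vector<int64_t> dims,
                                         const ResolvedDims& resolved) {
  for (std::size_t k = 0; k < dims.size() && k < resolved.size(); ++k) {
    if (resolved[k].has_value()) dims[k] = *resolved[k];
  }
  return dims;
}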
35 changes: 31 additions & 4 deletions paddle/cinn/hlir/framework/pir/fusion_info.cc
@@ -112,6 +112,11 @@ std::ostream& operator<<(std::ostream& os, const FusionOpInfo& info) {
}

FusionInfo::FusionInfo(const OpLoweringGroup& group) {
+ ParseOpInfos(group);
+ ParseInputDimExprs(group);
+ }
+
+ void FusionInfo::ParseOpInfos(const OpLoweringGroup& group) {
std::unordered_map<const ::pir::Operation*, size_t> op_mapper;
unique_fn_name_ = group.FuncName();

@@ -141,15 +146,37 @@ FusionInfo::FusionInfo(const OpLoweringGroup& group) {
op_infos_.emplace_back(*op, GetInnerUpstreamOps(op));
op_mapper.insert({op, i});
}
- auto& shape_analysis =
- ::pir::ShapeAnalysisManager::Instance().Get(group.GetParentProgram());
- for (const auto& value : group.GetInputOpValues()) {
+ }
+
+ void FusionInfo::ParseInputDimExprs(const OpLoweringGroup& group) {
+ // NOTE(Aurelius84): [Why try to get DimExprs from the Group first?]
+ // In the BroadcastTree case, we clone many Groups containing the same ops.
+ // But their input values are defined outside the groups and share the same
+ // DimExprs in the global ShapeAnalysis, which leads to unexpected hash
+ // conflicts.
+ const auto TryGetDimExprsFromGroup = [&](const ::pir::Value& value) -> bool {
+ if (!group.HasShapeOrDataExprs(value)) return false;
+ input_dim_exprs_.push_back(group.GetShapeOrDataExprs(value));
+ return true;
+ };
+ // NOTE(Aurelius84): If we can't get DimExprs from the Group, we fall back
+ // to the global ShapeAnalysis.
+ const auto TryeGetDimExprsFromGlobal =
+ [&](const ::pir::Value& value) -> bool {
+ auto& shape_analysis =
+ ::pir::ShapeAnalysisManager::Instance().Get(group.GetParentProgram());
if (!shape_analysis.HasShapeOrDataForValue(value)) {
VLOG(4) << "FusionInfo: input value doesn't have shape or data, skip it."
<< value.impl();
- continue;
+ return false;
}
input_dim_exprs_.push_back(shape_analysis.GetShapeOrDataForValue(value));
+ return true;
+ };
+
+ for (const auto& value : group.GetInputOpValues()) {
+ if (!TryGetDimExprsFromGroup(value)) {
+ TryeGetDimExprsFromGlobal(value);
+ }
}
}

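The NOTE above is the core of the commit: when BroadcastTree specialization clones a group, every clone contains the same ops, and since the input values are defined outside the group, the global ShapeAnalysis reports identical DimExprs for all clones, so a cache key built only from the global table would collide. ParseInputDimExprs therefore prefers the group-local ShapeOrDataExprs and falls back to the global ShapeAnalysis only when the group has none. A minimal standalone sketch of that local-first lookup, with plain STL containers standing in for pir::Value and symbol::ShapeOrDataDimExprs rather than Paddle's actual API:

#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for the real types.
using Value = int;             // pir::Value
using DimExprs = std::string;  // symbol::ShapeOrDataDimExprs

struct Group {
  // Group-local symbolic shapes, which can differ between cloned groups.
  std::unordered_map<Value, DimExprs> local_shape_exprs;
};

// Global ShapeAnalysis table, shared by every clone.
std::unordered_map<Value, DimExprs> global_shape_exprs;

// Prefer the group-local DimExprs and consult the global table only as a
// fallback, mirroring the intent of ParseInputDimExprs above.
std::optional<DimExprs> GetInputDimExprs(const Group& group, Value v) {
  if (auto it = group.local_shape_exprs.find(v);
      it != group.local_shape_exprs.end()) {
    return it->second;
  }
  if (auto it = global_shape_exprs.find(v); it != global_shape_exprs.end()) {
    return it->second;
  }
  return std::nullopt;  // no symbolic shape recorded for this input
}

Under this scheme, two clones that specialize the same input differently contribute different entries to input_dim_exprs_, so their FusionInfo cache keys no longer collide.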
3 changes: 3 additions & 0 deletions paddle/cinn/hlir/framework/pir/fusion_info.h
@@ -90,6 +90,9 @@ class FusionInfo {
friend std::ostream &operator<<(std::ostream &os, const FusionInfo &info);

private:
void ParseOpInfos(const OpLoweringGroup &group);
void ParseInputDimExprs(const OpLoweringGroup &group);

std::vector<FusionOpInfo> op_infos_;
std::vector<::symbol::ShapeOrDataDimExprs> input_dim_exprs_;
std::size_t cached_hash_value_{0};
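fusion_info.h now declares the two parsing helpers and keeps the combined digest in cached_hash_value_. The diff does not show how the members are folded into that hash; a conventional boost-style hash_combine over op_infos_ and the newly parsed input_dim_exprs_ would look roughly like the sketch below (hypothetical helper, not FusionInfo's actual hash implementation):

#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Boost-style hash_combine: folds a new hash value into an existing seed.
inline void HashCombine(std::size_t& seed, std::size_t value) {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

// Hypothetical reduction: assume each op info already provides a hash and each
// input DimExprs can be rendered as a string; the real FusionInfo works with
// FusionOpInfo and symbol::ShapeOrDataDimExprs objects instead.
std::size_t CombineFusionHash(const std::vector<std::size_t>& op_info_hashes,
                              const std::vector<std::string>& input_dim_exprs) {
  std::size_t seed = 0;
  for (std::size_t h : op_info_hashes) HashCombine(seed, h);
  for (const auto& expr : input_dim_exprs) {
    HashCombine(seed, std::hash<std::string>{}(expr));
  }
  return seed;
}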
2 changes: 0 additions & 2 deletions test/ir/pir/cinn/inference/test_llama_forward.py
@@ -87,8 +87,6 @@ def eval(self, use_cinn):
return out

def test_eval(self):
- # TODO(Aurelius84):disable compilation cache
- paddle.set_flags({"FLAGS_enable_cinn_compile_cache": False})
dy_out = self.eval(use_cinn=False)
cinn_out = self.eval(use_cinn=True)
np.testing.assert_allclose(
1 change: 0 additions & 1 deletion test/ir/pir/cinn/inference/test_llama_inference.py
@@ -190,7 +190,6 @@ def test_eval(self):
paddle.set_flags(
{
"FLAGS_prim_forward_blacklist": "pd_op.embedding;pd_op.softmax",
"FLAGS_enable_cinn_compile_cache": False,
}
)
cinn_out = self.eval(use_cinn=True)
2 changes: 0 additions & 2 deletions test/ir/pir/cinn/inference/test_llama_postprocess.py
@@ -109,8 +109,6 @@ def eval(self, use_cinn):
return out

def test_eval(self):
- # TODO(Aurelius84):disable compilation cache
- paddle.set_flags({"FLAGS_enable_cinn_compile_cache": False})
dy_out = self.eval(use_cinn=False)
cinn_out = self.eval(use_cinn=True)
# TODO(Aurelius84): fix the precision with inf
