diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc
index 3fa26f51b5592..326b2126758ed 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc
@@ -61,26 +61,8 @@ pir::Operation* ProcessDyShapeGroup(
   } else {  // no condition block
     // compile group to jit_kernel_op
     std::vector<pir::Type> output_types;
-    const auto& group_output_values = group->output_values();
-    for (size_t i = 0; i < group_output_values.size(); ++i) {
-      auto base_type =
-          group_output_values[i].type().dyn_cast<::pir::DenseTensorType>();
-      auto dim_info = base_type.dims();
-      if (shape_analysis.HasShapeOrDataForValue(group_output_values[i])) {
-        auto shape = group->GetShapeOrDataExprs(group_output_values[i]).shape();
-        for (size_t k = 0; k < shape.size(); ++k) {
-          if (shape[k].isa<int64_t>()) {
-            dim_info[k] = shape[k].Get<int64_t>();
-          }
-        }
-      }
-      auto new_type = ::pir::DenseTensorType::get(pir::IrContext::Instance(),
-                                                  base_type.dtype(),
-                                                  dim_info,
-                                                  base_type.data_layout(),
-                                                  base_type.lod(),
-                                                  base_type.offset());
-      output_types.push_back(new_type);
+    for (const auto& value : group->output_values()) {
+      output_types.push_back(value.type());
     }
     auto jit_kernel_op = rewriter.Build<cinn::dialect::JitKernelOp>(
         group_inputs, GetJitKernelAttr(group), output_types);
diff --git a/paddle/cinn/hlir/framework/pir/fusion_info.cc b/paddle/cinn/hlir/framework/pir/fusion_info.cc
index c8c3d1b766829..660c9e487ec4b 100644
--- a/paddle/cinn/hlir/framework/pir/fusion_info.cc
+++ b/paddle/cinn/hlir/framework/pir/fusion_info.cc
@@ -112,6 +112,11 @@ std::ostream& operator<<(std::ostream& os, const FusionOpInfo& info) {
 }
 
 FusionInfo::FusionInfo(const OpLoweringGroup& group) {
+  ParseOpInfos(group);
+  ParseInputDimExprs(group);
+}
+
+void FusionInfo::ParseOpInfos(const OpLoweringGroup& group) {
   std::unordered_map<const ::pir::Operation*, size_t> op_mapper;
   unique_fn_name_ = group.FuncName();
 
@@ -141,15 +146,37 @@ FusionInfo::FusionInfo(const OpLoweringGroup& group) {
     op_infos_.emplace_back(*op, GetInnerUpstreamOps(op));
     op_mapper.insert({op, i});
   }
-  auto& shape_analysis =
-      ::pir::ShapeAnalysisManager::Instance().Get(group.GetParentProgram());
-  for (const auto& value : group.GetInputOpValues()) {
+}
+
+void FusionInfo::ParseInputDimExprs(const OpLoweringGroup& group) {
+  // NOTE(Aurelius84): [Why try to get DimExprs from the Group first?]
+  // In case of BroadcastTree, we clone many Groups containing the same ops,
+  // but their input values are defined outside and share the same DimExprs
+  // in the global ShapeAnalysis, which leads to unexpected hash conflicts.
+  const auto TryGetDimExprsFromGroup = [&](const ::pir::Value& value) -> bool {
+    if (!group.HasShapeOrDataExprs(value)) return false;
+    input_dim_exprs_.push_back(group.GetShapeOrDataExprs(value));
+    return true;
+  };
+  // NOTE(Aurelius84): If we can't get DimExprs from the Group, fall back to
+  // the global ShapeAnalysis.
+  const auto TryGetDimExprsFromGlobal =
+      [&](const ::pir::Value& value) -> bool {
+    auto& shape_analysis =
+        ::pir::ShapeAnalysisManager::Instance().Get(group.GetParentProgram());
     if (!shape_analysis.HasShapeOrDataForValue(value)) {
       VLOG(4) << "FusionInfo: input value doesn't have shape or data, skip it."
              << value.impl();
-      continue;
+      return false;
     }
     input_dim_exprs_.push_back(shape_analysis.GetShapeOrDataForValue(value));
-  }
+    return true;
+  };
+
+  for (const auto& value : group.GetInputOpValues()) {
+    if (!TryGetDimExprsFromGroup(value)) {
+      TryGetDimExprsFromGlobal(value);
+    }
+  }
 }
 
diff --git a/paddle/cinn/hlir/framework/pir/fusion_info.h b/paddle/cinn/hlir/framework/pir/fusion_info.h
index 04e482ba4c922..8290ef0c7d259 100644
--- a/paddle/cinn/hlir/framework/pir/fusion_info.h
+++ b/paddle/cinn/hlir/framework/pir/fusion_info.h
@@ -90,6 +90,9 @@ class FusionInfo {
   friend std::ostream &operator<<(std::ostream &os, const FusionInfo &info);
 
  private:
+  void ParseOpInfos(const OpLoweringGroup &group);
+  void ParseInputDimExprs(const OpLoweringGroup &group);
+
   std::vector<FusionOpInfo> op_infos_;
   std::vector<::symbol::ShapeOrDataDimExprs> input_dim_exprs_;
   std::size_t cached_hash_value_{0};
diff --git a/test/ir/pir/cinn/inference/test_llama_forward.py b/test/ir/pir/cinn/inference/test_llama_forward.py
index 51381d59e6d95..eb41f6ce3f941 100644
--- a/test/ir/pir/cinn/inference/test_llama_forward.py
+++ b/test/ir/pir/cinn/inference/test_llama_forward.py
@@ -87,8 +87,6 @@ def eval(self, use_cinn):
         return out
 
     def test_eval(self):
-        # TODO(Aurelius84):disable compilation cache
-        paddle.set_flags({"FLAGS_enable_cinn_compile_cache": False})
         dy_out = self.eval(use_cinn=False)
         cinn_out = self.eval(use_cinn=True)
         np.testing.assert_allclose(
diff --git a/test/ir/pir/cinn/inference/test_llama_inference.py b/test/ir/pir/cinn/inference/test_llama_inference.py
index 092a23edbfd27..20c0e88395861 100644
--- a/test/ir/pir/cinn/inference/test_llama_inference.py
+++ b/test/ir/pir/cinn/inference/test_llama_inference.py
@@ -190,7 +190,6 @@ def test_eval(self):
         paddle.set_flags(
             {
                 "FLAGS_prim_forward_blacklist": "pd_op.embedding;pd_op.softmax",
-                "FLAGS_enable_cinn_compile_cache": False,
             }
         )
         cinn_out = self.eval(use_cinn=True)
diff --git a/test/ir/pir/cinn/inference/test_llama_postprocess.py b/test/ir/pir/cinn/inference/test_llama_postprocess.py
index 6fc17b6d19ae7..b8bdb1f0224ec 100644
--- a/test/ir/pir/cinn/inference/test_llama_postprocess.py
+++ b/test/ir/pir/cinn/inference/test_llama_postprocess.py
@@ -109,8 +109,6 @@ def eval(self, use_cinn):
         return out
 
     def test_eval(self):
-        # TODO(Aurelius84):disable compilation cache
-        paddle.set_flags({"FLAGS_enable_cinn_compile_cache": False})
         dy_out = self.eval(use_cinn=False)
         cinn_out = self.eval(use_cinn=True)
         # TODO(Aurelius84): fix the precision with inf