Skip to content

Commit

Permalink
[PIR+CINN]Add FusionInfo interface and Polish CompilationCache
Browse files Browse the repository at this point in the history
fix compilation problem

fix conflict

fix conflict
  • Loading branch information
Aurelius84 committed Apr 7, 2024
1 parent 81ae815 commit f744be8
Show file tree
Hide file tree
Showing 17 changed files with 588 additions and 231 deletions.
8 changes: 4 additions & 4 deletions cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ cinn_cc_library(
${jitify_deps})
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
target_link_libraries(cinnapi op_dialect pir phi)
add_dependencies(cinnapi op_dialect pir phi)
target_link_libraries(cinnapi op_dialect cinn_op_dialect pir phi)
add_dependencies(cinnapi op_dialect cinn_op_dialect pir phi)

target_link_libraries(cinnapi ${PYTHON_LIBRARIES})

Expand Down Expand Up @@ -229,8 +229,8 @@ function(gen_cinncore LINKTYPE)
${jitify_deps})
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
target_link_libraries(${CINNCORE_TARGET} op_dialect pir phi)
add_dependencies(${CINNCORE_TARGET} op_dialect pir phi)
target_link_libraries(${CINNCORE_TARGET} op_dialect cinn_op_dialect pir phi)
add_dependencies(${CINNCORE_TARGET} op_dialect cinn_op_dialect pir phi)

# add_dependencies(${CINNCORE_TARGET} pybind)
target_link_libraries(${CINNCORE_TARGET} ${PYTHON_LIBRARIES})
Expand Down
32 changes: 26 additions & 6 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ pir::Block* GroupOp::block() {

// Returns the single block owned by this GroupOp's region(0).
pir::Block* GroupOp::block() const {
  pir::Region& region = (*this)->region(0);
  // A well-formed GroupOp always carries a non-empty region.
  PADDLE_ENFORCE_EQ(region.empty(),
                    false,
                    ::common::errors::Unavailable(
                        "Required GroupOp's region must not be empty."));
  return &region.front();
}

Expand Down Expand Up @@ -156,7 +159,16 @@ pir::Block* FusionOp::block() {
return &region.front();
}

std::vector<pir::Operation*> FusionOp::GetOperators() {
// Returns the single block owned by this FusionOp's region(0).
pir::Block* FusionOp::block() const {
  pir::Region& region = (*this)->region(0);
  // A well-formed FusionOp always carries a non-empty region.
  PADDLE_ENFORCE_EQ(region.empty(),
                    false,
                    ::common::errors::Unavailable(
                        "Required FusionOp's region must not be empty."));
  return &region.front();
}

std::vector<pir::Operation*> FusionOp::GetOperators() const {
std::vector<pir::Operation*> rt_ops;
for (auto& op : *block()) {
rt_ops.push_back(&op);
Expand Down Expand Up @@ -305,7 +317,9 @@ void GenerateShapeOp::Build(
if (inputs.empty()) {
VLOG(3) << "GenerateShapeOp inputs is empty";
for (const auto& attr : output_dim_exprs) {
CHECK(attr.isa<pir::Int64Attribute>());
PADDLE_ENFORCE(attr.isa<pir::Int64Attribute>(),
::common::errors::PreconditionNotMet(
"Reqiured attr must be Int64Attribute."));
}
}
argument.AddInputs(inputs);
Expand Down Expand Up @@ -467,11 +481,15 @@ bool GenerateShapeOp::InferSymbolicShape(
const auto attr_dim_exprs = [&] {
std::vector<symbol::DimExpr> dim_exprs{};
pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
CHECK(dim_expr_attr.isa<pir::ArrayAttribute>());
PADDLE_ENFORCE(dim_expr_attr.isa<pir::ArrayAttribute>(),
::common::errors::PreconditionNotMet(
"Required dim_expr_attr is ArrayAttribute."));
auto array = dim_expr_attr.dyn_cast<pir::ArrayAttribute>();
for (int i = 0; i < array.size(); ++i) {
const auto& dim_expr = ConvertAttributeToDimExpr(array.at(i));
CHECK(dim_expr.has_value());
PADDLE_ENFORCE(dim_expr.has_value(),
::common::errors::PreconditionNotMet(
"Required dim_expr.has_value()==true."));
dim_exprs.push_back(dim_expr.value());
}
return dim_exprs;
Expand All @@ -481,7 +499,9 @@ bool GenerateShapeOp::InferSymbolicShape(
this->attributes().at("symbol_bindings");
auto symbol_bindings =
ConvertAttributeToSymbolBindings(symbol_bindings_attr);
CHECK(symbol_bindings.has_value());
PADDLE_ENFORCE(symbol_bindings.has_value(),
::common::errors::PreconditionNotMet(
"Required symbol_bindings.has_value()==true."));
return symbol_bindings.value();
}();
auto DimExprs4InputDim =
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class IR_API FusionOp : public pir::Op<FusionOp> {
const cinn::dialect::GroupInfo &group_info);

pir::Block *block();
std::vector<pir::Operation *> GetOperators();
pir::Block *block() const;

std::vector<pir::Operation *> GetOperators() const;

void VerifySig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ pir::Operation* ProcessDyShapeGroup(
const OpLoweringGroupPtr& group,
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
pir::PatternRewriter& rewriter) { // NOLINT
// NOTE(dev): Need UpdateShapeOrDataExprs firstly.
group->UpdateShapeOrDataExprs();
auto group_inputs = GetBlockOutsideInput(group->ops());
GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
const auto& leaves = group_dim_expr_info.all_value_dim_exprs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) {

std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
const OpLoweringGroupPtr& group) {
auto kernel_info = CompilationCache::Instance().GetKernelInfo(group);
hlir::framework::pir::FusionInfo fusion_info(*group);
auto kernel_info = CompilationCache::Instance().GetKernelInfo(fusion_info);
std::unordered_map<std::string, ::pir::Attribute> attrs{
{cinn::dialect::JitKernelOp::kAttrName,
cinn::dialect::CINNKernelInfoAttribute::get(pir::IrContext::Instance(),
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/hlir/framework/pir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ gather_srcs(
trivial_op_impl.cc
trivial_op_util.cc
compilation_task.cc
compilation_cache.cc)
compilation_cache.cc
fusion_info.cc)
41 changes: 9 additions & 32 deletions paddle/cinn/hlir/framework/pir/compilation_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,19 @@ void* BackendResource::GetInferFuncPtr() const {
return ptr;
}

std::shared_ptr<backends::Compiler>& BackendResource::GetBackendCompiler() {
return backend_compiler_;
}

const std::shared_ptr<backends::Compiler>& BackendResource::GetBackendCompiler()
const {
return backend_compiler_;
}

void BackendResource::SetHostFnName(const std::string& name) {
host_fn_name_ = name;
}

void BackendResource::SetInferFnName(const std::string& name) {
infer_fn_name_ = name;
}

// Assembles a CINNKernelInfo from the jit-compiled host/infer function
// symbols and the int-argument mapping captured at construction time.
pir::CINNKernelInfo BackendResource::GenerateKernelInfo() const {
  pir::CINNKernelInfo kernel_info;
  kernel_info.fn_name = host_fn_name_;
  kernel_info.fn_ptr = GetHostFuncPtr();
  kernel_info.infer_shape_fn_ptr = GetInferFuncPtr();
  kernel_info.int_args_map = GetIntArgsMap();
  return kernel_info;
}
} // namespace pir

// Whether a compilation result has already been cached for |key|.
bool CompilationCache::Has(const CacheKey& key) const {
  const bool has_existed = cache_.find(key) != cache_.end();
  VLOG(6) << "Check IsExisted in CompilationCache: " << key << " "
          << has_existed;
  return has_existed;
}
Expand All @@ -79,24 +61,19 @@ const CompilationCache::CacheValue& CompilationCache::Get(
PADDLE_ENFORCE_EQ(
Has(key),
true,
phi::errors::NotFound("%s is not in CompliatonCache.", key->FuncName()));
return cache_.at(KeyHash(key));
phi::errors::NotFound("%s is not in CompliatonCache.", key));
return cache_.at(key);
}

// Convenience wrapper: look up |key| and materialize its kernel info.
pir::CINNKernelInfo CompilationCache::GetKernelInfo(const CacheKey& key) const {
  return Get(key)->GetKernelInfo();
}

// Stores |value| under |key|. NOTE: unordered_map::insert keeps the first
// entry if |key| is already present; it does not overwrite.
void CompilationCache::Insert(const CacheKey& key, const CacheValue& value) {
  VLOG(6) << "Insert CompilationCache for: " << key;
  cache_.insert({key, value});
}

void CompilationCache::Clear() {
  // Drop every cached compilation result; later lookups will recompile.
  cache_.clear();
}

// NOTE(review): legacy hash kept from when CacheKey was a shared_ptr to
// OpLoweringGroup; it keys solely on the group's function name, so two
// distinct groups sharing a FuncName would collide — confirm FuncName
// uniqueness before relying on this.
size_t CompilationCache::KeyHash(const CacheKey& key) const {
// TODO(Aurelius84): use a better hash function in next pr.
return std::hash<std::string>{}(key->FuncName());
}

} // namespace cinn::hlir::framework
59 changes: 31 additions & 28 deletions paddle/cinn/hlir/framework/pir/compilation_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/pir/fusion_info.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"

namespace cinn::hlir::framework {
Expand All @@ -27,61 +28,64 @@ namespace pir {
class OpLoweringGroup;
class BackendResource final {
public:
BackendResource(const Target& target) {
backend_compiler_ = backends::Compiler::Create(target);
}

BackendResource(const Target& target,
const std::string& host_fn_name,
const std::string& infer_fn_name)
: host_fn_name_(host_fn_name), infer_fn_name_(infer_fn_name) {
const std::string& infer_fn_name,
const std::map<int, CINNKernelInfo::ArgDimIdx>& int_args_map)
: host_fn_name_(host_fn_name),
infer_fn_name_(infer_fn_name),
int_args_map_(int_args_map) {
backend_compiler_ = backends::Compiler::Create(target);
}

void* GetHostFuncPtr() const;
void* GetInferFuncPtr() const;
pir::CINNKernelInfo GernerateKernelInfo(
const std::shared_ptr<pir::OpLoweringGroup>& group) const;
std::shared_ptr<backends::Compiler>& GetBackendCompiler();
const std::shared_ptr<backends::Compiler>& GetBackendCompiler() const;
void SetHostFnName(const std::string& name);
void SetInferFnName(const std::string& name);
const std::map<int, CINNKernelInfo::ArgDimIdx>& GetIntArgsMap() const {
return int_args_map_;
}
const std::shared_ptr<backends::Compiler>& GetBackendCompiler() const {
return backend_compiler_;
}
pir::CINNKernelInfo GenerateKernelInfo() const;

private:
std::string host_fn_name_;
std::string infer_fn_name_;
// std::string host_code_;
// std::vector<std::string> device_code_;
std::shared_ptr<backends::Compiler> backend_compiler_;
std::map<int, CINNKernelInfo::ArgDimIdx> int_args_map_;

std::shared_ptr<backends::Compiler> backend_compiler_{nullptr};
};

class CompilationResult final {
public:
explicit CompilationResult(const Target& target)
: target_(target), backend_resource_(target) {}

BackendResource& MutableBackendResource() { return backend_resource_; }
const BackendResource& GetBackendResource() const {
explicit CompilationResult(const Target& target) : target_(target) {}
const std::shared_ptr<BackendResource>& GetBackendResource() const {
return backend_resource_;
}
pir::CINNKernelInfo GetKernelInfo(
const std::shared_ptr<pir::OpLoweringGroup>& group) {
return backend_resource_.GernerateKernelInfo(group);

void SetBackendResource(const std::shared_ptr<BackendResource>& other) {
backend_resource_ = other;
}

pir::CINNKernelInfo GetKernelInfo() {
// TODO(Aurelius84): add ENFORCE_NOT_NULL
return backend_resource_->GenerateKernelInfo();
}

private:
Target target_;
BackendResource backend_resource_;
std::shared_ptr<BackendResource> backend_resource_{nullptr};
};

} // namespace pir

class CompilationCache {
public:
using CacheKey = std::shared_ptr<pir::OpLoweringGroup>;
using CacheKey = pir::FusionInfo;
using CacheValue = std::shared_ptr<pir::CompilationResult>;

static CompilationCache& Instance() {
static CompilationCache instance;
thread_local static CompilationCache instance;
return instance;
}

Expand All @@ -90,13 +94,12 @@ class CompilationCache {
pir::CINNKernelInfo GetKernelInfo(const CacheKey& key) const;
void Insert(const CacheKey& key, const CacheValue& value);
void Clear();
size_t KeyHash(const CacheKey& key) const;

private:
CompilationCache() = default;
CINN_DISALLOW_COPY_AND_ASSIGN(CompilationCache);

std::unordered_map<size_t, CacheValue> cache_;
std::unordered_map<CacheKey, CacheValue> cache_;
};

} // namespace cinn::hlir::framework
40 changes: 14 additions & 26 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,9 @@ std::string GroupCompilationContext::PrintPredicate2Funcs() const {
return ss.str();
}

void CompilationTask::operator()() {
VLOG(4) << "Run Compilation Task for : " << context_->group_.get();
if (CompilationCache::Instance().Has(context_->group_)) {
VLOG(4) << "Found cached kernel info for group: "
<< context_->group_->FuncName();
return;
}
std::shared_ptr<pir::CompilationResult> CompilationTask::operator()() {
Lowering();
CodegenAndJit();
return CodegenAndJit();
}

void CompilationTask::Lowering() {
Expand All @@ -62,7 +56,7 @@ void CompilationTask::Lowering() {
/* apply pass = */ true));
}

void CompilationTask::CodegenAndJit() {
std::shared_ptr<pir::CompilationResult> CompilationTask::CodegenAndJit() {
ir::Module::Builder builder(cinn::common::UniqName("module"),
context_->target_);
CHECK_EQ(context_->predicates_.size(), context_->lowered_funcs_.size());
Expand All @@ -74,27 +68,21 @@ void CompilationTask::CodegenAndJit() {
}
builder.SetInferShapeFunc(context_->infer_shape_lowered_func_);
ir::Module ir_module = builder.Build();
BuildPirCINNKernelInfo(ir_module);
}

pir::CINNKernelInfo CompilationTask::GetCINNKernelInfo() {
if (!CompilationCache::Instance().Has(context_->group_)) {
PADDLE_THROW(phi::errors::NotFound(
"Kernel info has been cached for current group."));
}
return CompilationCache::Instance().GetKernelInfo(context_->group_);
return BuildPirCINNKernelInfo(ir_module);
}

void CompilationTask::BuildPirCINNKernelInfo(const ir::Module& module) {
std::shared_ptr<pir::CompilationResult> CompilationTask::BuildPirCINNKernelInfo(
const ir::Module& module) {
auto compilation_result =
std::make_shared<pir::CompilationResult>(context_->target_);
pir::BackendResource& backend_resource =
compilation_result->MutableBackendResource();
backend_resource.GetBackendCompiler()->Build(module, "");
backend_resource.SetHostFnName(context_->group_->FuncName());
backend_resource.SetInferFnName(context_->group_->FuncName() +
"_infer_shape");
CompilationCache::Instance().Insert(context_->group_, compilation_result);
auto backend_resource = std::make_shared<pir::BackendResource>(
context_->target_,
context_->group_->FuncName(),
context_->group_->FuncName() + "_infer_shape",
context_->group_->int_args_map());
backend_resource->GetBackendCompiler()->Build(module, "");
compilation_result->SetBackendResource(backend_resource);
return compilation_result;
}

} // namespace framework
Expand Down
Loading

0 comments on commit f744be8

Please sign in to comment.