This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit d118a13

Merge branch 'PaddlePaddle:develop' into sort

zrr1999 authored Sep 9, 2022
2 parents ccedd64 + b9d0896
Showing 42 changed files with 1,023 additions and 223 deletions.
Empty file modified cinn/backends/codegen_cuda_host.cc (file mode 100644 → 100755, no content changes)
13 changes: 5 additions & 8 deletions cinn/backends/codegen_cuda_util.h
@@ -57,15 +57,12 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
 
  private:
   void Visit(const ir::_LoweredFunc_* op, Expr* expr) override {
-    if (op->cuda_axis_info.valid()) {
-      CHECK(op->cuda_axis_info.valid());
-
-      auto host_func = CreateHostFunctionGivenDeviceKernel(op);
-      host_module_builder.AddFunction(host_func.as_lowered_func_ref());
-      device_module_builder.AddFunction(CreateDeviceFunctionGivenDeviceKernel(*expr).as_lowered_func_ref());
-    } else {
-      host_module_builder.AddFunction(expr->as_lowered_func_ref());
+    if (!op->cuda_axis_info.valid()) {
+      expr->as_lowered_func_ref()->cuda_axis_info.set_valid(true);
     }
+    auto host_func = CreateHostFunctionGivenDeviceKernel(op);
+    host_module_builder.AddFunction(host_func.as_lowered_func_ref());
+    device_module_builder.AddFunction(CreateDeviceFunctionGivenDeviceKernel(*expr).as_lowered_func_ref());
   }
 
   /**
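For readability, the resulting Visit body after this hunk, reconstructed from the added lines above (indentation is approximate): lowered functions whose cuda_axis_info is not yet valid are now marked valid, and every function goes through the same host/device emission path.

  void Visit(const ir::_LoweredFunc_* op, Expr* expr) override {
    if (!op->cuda_axis_info.valid()) {
      expr->as_lowered_func_ref()->cuda_axis_info.set_valid(true);
    }
    auto host_func = CreateHostFunctionGivenDeviceKernel(op);
    host_module_builder.AddFunction(host_func.as_lowered_func_ref());
    device_module_builder.AddFunction(CreateDeviceFunctionGivenDeviceKernel(*expr).as_lowered_func_ref());
  }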
9 changes: 6 additions & 3 deletions cinn/backends/compiler_test.cc
@@ -170,9 +170,12 @@ TEST(Compiler, sqrt) {
   Placeholder<float> bias("bias", {C});
   float epsilon = 0.1f;
 
-  auto A = Compute({N, C, H, W}, [=](Expr n, Expr c, Expr h, Expr w) {
-    return (input(n, c, h, w) - mean(c)) * scale(c) / lang::Sqrt(variance(c) + Expr(epsilon)) + bias(c);
-  });
+  auto A = Compute(
+      {N, C, H, W},
+      [=](Expr n, Expr c, Expr h, Expr w) {
+        return (input(n, c, h, w) - mean(c)) * scale(c) / lang::Sqrt(variance(c) + Expr(epsilon)) + bias(c);
+      },
+      "A");
 
   auto B = hlir::pe::Pool2d(input, {3, 3}, {1, 1}, {1, 1, 1, 1}, "max", false, false);

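This hunk, like the codegen_x86_test.cc and cas_test.cc hunks below, switches the tests to the Compute overload that takes the tensor name as an explicit trailing argument. A minimal sketch of the updated calling convention (the tensor names here are illustrative, not part of the commit):

  // Hypothetical example of the explicit-name Compute call used throughout these tests.
  auto T = Compute(
      {M},
      [&](Expr i) { return A(i) + B(i); },
      "T");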
3 changes: 2 additions & 1 deletion cinn/backends/llvm/codegen_x86_test.cc
@@ -29,7 +29,8 @@ TEST(Vectorize, basic) {
   Placeholder<float> A("A", {M});
   Placeholder<float> B("B", {M});
 
-  auto C = Compute({M}, [&](Expr i) { return A(i) + B(i); });
+  auto C = Compute(
+      {M}, [&](Expr i) { return A(i) + B(i); }, "C");
   auto stages = CreateStages({C});
 
   stages[C]->Vectorize(0, 8);
26 changes: 13 additions & 13 deletions cinn/common/cas.cc (file mode 100644 → 100755)
@@ -34,7 +34,7 @@ namespace common {
 using namespace ir;  // NOLINT
 
 Expr AutoSimplify(Expr u, const absl::flat_hash_map<std::string, CasInterval>& var_intervals) {
-  VLOG(3) << "Begin AutoSimplify: " << u;
+  VLOG(7) << "Begin AutoSimplify: " << u;
   u = detail::ConvertCinnToCAS(u);
   absl::flat_hash_map<std::string, CasInterval> s_var_intervals;
   for (auto& item : var_intervals) {
@@ -48,7 +48,7 @@ Expr AutoSimplify(Expr u, const absl::flat_hash_map<std::string, CasInterval>& v
   }
   u = CasSimplify(u, s_var_intervals);
   u = detail::ConvertCasToCinn(u);
-  VLOG(3) << "End AutoSimplify " << u;
+  VLOG(7) << "End AutoSimplify " << u;
   return u;
 }
 
@@ -407,7 +407,7 @@ double EvaluatePower(Expr u) {
 // Order, reference to Page 85.
 bool ExprPosCmp::operator()(const Expr& a, const Expr& b) {
   // O-1, 1 <| 2
-  VLOG(3) << "Begin ExprPosCmp, a: " << a << ", b: " << b;
+  VLOG(7) << "Begin ExprPosCmp, a: " << a << ", b: " << b;
   if (a.is_constant() && b.is_constant()) {
     return a.get_constant() < b.get_constant();
   }
@@ -743,7 +743,7 @@ Expr CasSimplifyMutator::SimplifyProduct(Expr a) {
     for (auto& v : operands) {
       ss << v << " ";
     }
-    VLOG(6) << "operands: " << ss.str();
+    VLOG(7) << "operands: " << ss.str();
   };
 #endif
 
@@ -822,11 +822,11 @@ std::vector<Expr> CasSimplifyMutator::MergeSum(const std::vector<Expr>& p, const
     std::stringstream ss;
     for (auto& x : p) ss << x << " ";
 
-    VLOG(6) << "MergeSum p(" << ss.str() << ")";
+    VLOG(7) << "MergeSum p(" << ss.str() << ")";
     ss.str("");
 
     for (auto& x : q) ss << x << " ";
-    VLOG(6) << "MergeSum q(" << ss.str() << ")";
+    VLOG(7) << "MergeSum q(" << ss.str() << ")";
     ss.str("");
   }
 #endif
@@ -909,8 +909,8 @@ std::vector<Expr> CasSimplifyMutator::SimplifyBinarySum(Expr left, Expr right) {
     auto a_non_constant = ProductGetNonConstantPart(a);
     auto b_non_constant = ProductGetNonConstantPart(b);
     if (a_non_constant.defined() && b_non_constant.defined() && a_non_constant == b_non_constant) {
-      VLOG(3) << "a " << a;
-      VLOG(3) << "b " << b;
+      VLOG(7) << "a " << a;
+      VLOG(7) << "b " << b;
       Expr s = SimplifySum(Sum::Make({ProductGetConstantPart(a), ProductGetConstantPart(b)}));
       Expr p = Product::Make({s, ProductGetNonConstantPart(a)});
       return {CasSimplify(p, var_intervals)};
@@ -961,7 +961,7 @@ std::vector<Expr> CasSimplifyMutator::SimplifySumRec(const std::vector<Expr>& op
     for (auto& o : operands) {
       ss << o.node_type() << " " << o << " ";
     }
-    VLOG(6) << "SimplifySumRec operands: " << ss.str();
+    VLOG(7) << "SimplifySumRec operands: " << ss.str();
   }
 #endif
   CHECK(!operands.empty());
@@ -1649,7 +1649,7 @@ bool CASasSymbol(Expr expr) {
 }
 
 Expr ConvertCinnToCAS(Expr expr) {
-  VLOG(3) << "Begin ConvertCinnToCAS " << expr;
+  VLOG(7) << "Begin ConvertCinnToCAS " << expr;
   Expr copied = optim::IRCopy(expr);
   struct Mutator : public ir::IRMutator<ir::Expr*> {
     void operator()(Expr* expr) { Visit(expr); }
@@ -1842,7 +1842,7 @@ Expr ReplaceMaxToConstant(Expr expr) {
 }
 
 Expr ConvertCasToCinn(Expr expr) {
-  VLOG(3) << "Begin ConvertCasToCinn : " << expr;
+  VLOG(7) << "Begin ConvertCasToCinn : " << expr;
   Expr copied = optim::IRCopy(expr);
 
   struct Mutator : ir::IRMutator<Expr*> {
@@ -2051,7 +2051,7 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval(
     auto it = var_intervals.find(bv->name);
     auto ai_abs = std::abs(ai->value);
     if (it != var_intervals.end()) {
-      VLOG(3) << "found " << bv->name << " " << it->second << " "
+      VLOG(7) << "found " << bv->name << " " << it->second << " "
               << " ai " << ai_abs;
     }
     if (it != var_intervals.end() && std::abs(it->second.r) > ai_abs && std::abs(it->second.l) > ai_abs) {
@@ -2089,7 +2089,7 @@ Expr SimplifyConstantFrac(FracOp* node) {
 }
 
 Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) {
-  VLOG(3) << "CAS simplify Frac " << expr;
+  VLOG(7) << "CAS simplify Frac " << expr;
   auto* node = expr.As<FracOp>();
   auto a = CasSimplify(node->a(), var_intervals);
   auto b = CasSimplify(node->b(), var_intervals);
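All of the cas.cc hunks only raise the verbosity of the CAS tracing logs (VLOG(3)/VLOG(6) → VLOG(7)), so this output is suppressed by default. A hedged reminder of how such a statement behaves under standard glog semantics (the exact flag spelling is an assumption, not part of this commit):

  // Printed only when the process runs with verbose logging at level 7 or higher,
  // e.g. by setting GLOG_v=7 (or --v=7) before running the binary.
  VLOG(7) << "Begin AutoSimplify: " << u;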
6 changes: 4 additions & 2 deletions cinn/common/cas_test.cc
@@ -203,8 +203,10 @@ TEST(CAS, ConvertCinnToCAS) {
   Placeholder<float> A("A", {10, 10});
   Placeholder<float> B("B", {10, 10});
 
-  auto C = Compute({Expr(10), Expr(10)},
-                   [&](Expr i, Expr j) { return A(i, j) + 0.f + 1.f + 2.f * B(i, j) + 0.f * B(i, j) * A(i, j); });
+  auto C = Compute(
+      {Expr(10), Expr(10)},
+      [&](Expr i, Expr j) { return A(i, j) + 0.f + 1.f + 2.f * B(i, j) + 0.f * B(i, j) * A(i, j); },
+      "C");
 
   Expr body = C->body();
   LOG(INFO) << "body " << body;
Empty file modified cinn/common/ir_util.h (file mode 100644 → 100755, no content changes)
4 changes: 4 additions & 0 deletions cinn/common/type.cc (file mode 100644 → 100755)
@@ -143,6 +143,10 @@ Type Type::ConstOf() const {
   return x;
 }
 
+bool Type::is_supported() const {
+  return (*this == Float(32) || this->is_bool() || *this == Int(32) || *this == Int(64));
+}
+
 Type Type::IgnoreConst() const {
   CheckTypeValid();
   auto x = *this;
2 changes: 2 additions & 0 deletions cinn/common/type.h
@@ -123,6 +123,8 @@ struct Type {
   Type IgnoreConst() const;
   //! Add const.
   Type ConstOf() const;
+  //! Check if a dtype is supported in CINN yet.
+  bool is_supported() const;
 
   friend std::ostream& operator<<(std::ostream& os, const Type& t);
 
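Together with the cinn/common/type.cc hunk above, this adds Type::is_supported() as a single place to keep CINN's dtype whitelist; the graph_compiler.cc hunks below replace the repeated inline CHECKs with it. A minimal sketch of a call site, mirroring those hunks (the dtype and id variables are illustrative placeholders):

  // Hypothetical caller: validate a node's dtype against CINN's supported set.
  Type dtype = type_dict_.at(id);
  CHECK(dtype.is_supported()) << "Node " << id << " 's dtype " << dtype << "is not supported yet!";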
Empty file modified cinn/frontend/net_builder_test.cc (file mode 100644 → 100755, no content changes)
17 changes: 7 additions & 10 deletions cinn/hlir/framework/graph_compiler.cc
@@ -288,15 +288,14 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFuncWithIRSchedule(
   std::vector<ir::Tensor> tensor_inputs;
   std::vector<common::CINNValue> cinn_inputs;
   std::vector<std::string> input_output_nodes;
-  VLOG(3) << "GetOpFunc of op " << node->id();
+  VLOG(3) << "GetOpFunc of op " << node->id() << " with op type " << node->op()->name;
 
   // 1.Collect inputs info and outputs info
   for (auto& i : node->inlinks_in_order(true)) {
     std::string id = i->source()->as<NodeData>()->id();
     auto shape = shape_dict_.at(id);
     Type dtype = type_dict_.at(id);
-    CHECK(dtype == Float(32) || dtype.is_bool() || dtype == Int(32) || dtype == Int(64))
-        << "The dtype of node " << id << " is not float or bool or int! Other dtype is not implemented yet.";
+    CHECK(dtype.is_supported()) << "Node " << id << " 's dtype " << dtype << "is not supported yet!";
     ir::Tensor input;
     if (dtype == Float(32)) {
       input = lang::Placeholder<float>(id, shape);
@@ -345,8 +344,7 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const Node* node) {
     std::string input_id = i->source()->as<NodeData>()->id();
     auto in_shape = shape_dict.at(input_id);
     Type dtype = dtype_dict.at(input_id);
-    CHECK(dtype == Float(32) || dtype.is_bool() || dtype == Int(32) || dtype == Int(64))
-        << "The dtype of node " << input_id << " is not float or bool or int! Other dtype is not implemented yet.";
+    CHECK(dtype.is_supported()) << "Node " << input_id << " 's dtype " << dtype << "is not supported yet!";
     ir::Tensor temp;
     if (dtype == Float(32)) {
       temp = lang::Placeholder<float>(input_id, in_shape);
@@ -449,8 +447,7 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>&
     std::string input_id = source_data->id();
     auto in_shape = shape_dict.at(input_id);
     Type dtype = dtype_dict.at(input_id);
-    CHECK(dtype == Float(32) || dtype.is_bool() || dtype == Int(32) || dtype == Int(64))
-        << "The dtype of node " << input_id << " is not float or bool or int! Other dtype is not implemented yet.";
+    CHECK(dtype.is_supported()) << "Node " << input_id << " 's dtype " << dtype << "is not supported yet!";
     ir::Tensor temp_in;
     if (dtype == Float(32)) {
       temp_in = lang::Placeholder<float>(input_id, in_shape);
@@ -1552,15 +1549,15 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(const std::shared_ptr<OpImpl>& impl
       auto new_args = lang::GetArgs(funcs[i]->body, input_output_nodes);
       funcs[i]->args = new_args;
     }
-#ifdef CINN_WITH_CUDA
-    optim::OptimizeExprGPU(&(funcs[i]->body));
-#endif
     auto temp_buffers = lang::GetTempBuffers(all_arg_tensors, stages, funcs[i]->body);
     funcs[i]->temp_bufs = temp_buffers;
     funcs[i]->PrepareBufferCastExprs();
     res.push_back(funcs[i]);
   }
   for (int i = 0; i < res.size(); i++) {
+#ifdef CINN_WITH_CUDA
+    optim::OptimizeExprGPU(&(res[i]->body));
+#endif
     res[i] = optim::Optimize(Expr(res[i]), target, false).as_lowered_func_ref();
   }
 
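The last hunk moves the CINN_WITH_CUDA-guarded OptimizeExprGPU call out of the function-preparation loop and into the later loop that finalizes each lowered function, so it now rewrites res[i]->body rather than funcs[i]->body. The resulting second loop, reconstructed from the added lines above (indentation approximate):

  for (int i = 0; i < res.size(); i++) {
#ifdef CINN_WITH_CUDA
    optim::OptimizeExprGPU(&(res[i]->body));
#endif
    res[i] = optim::Optimize(Expr(res[i]), target, false).as_lowered_func_ref();
  }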