Skip to content

Commit

Permalink
refine lower (PaddlePaddle#110)
Browse files Browse the repository at this point in the history
* code clean

* make Lower's return from vector to LowerFunc

* move optime to Lower
  • Loading branch information
Superjomn authored Mar 30, 2020
1 parent 385c600 commit 91de356
Show file tree
Hide file tree
Showing 15 changed files with 188 additions and 157 deletions.
150 changes: 111 additions & 39 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ TEST(CodeGenC, module) {
Module module("module1", target);

auto funcs = Lower("add1", {A, B, C});
ASSERT_EQ(funcs.size(), 1UL);

module.Append(funcs.front());
module.Append(funcs);
module.Append(C_buf);

{
Expand All @@ -69,8 +68,8 @@ void add1(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, struct
const float* A = (const float*)(cinn_buffer_get_data_const_handle(_A));
const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
for (int32_t i = 0; (i <= 99); i += 1) {
for (int32_t j = 0; (j <= 19); j += 1) {
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j = 0; j < 20; j += 1) {
C[((20 * i) + j)] = (A[((20 * i) + j)] + B[((20 * i) + j)]);
};
};
Expand Down Expand Up @@ -137,9 +136,7 @@ TEST(CodeGenC, module_with_transform) {

auto funcs = Lower("add1", {A, B, C, D});

ASSERT_EQ(funcs.size(), 1UL);

Expr func(funcs.front());
Expr func(funcs);
optim::Simplify(&func);

module.Append(ir::LoweredFunc(func.As<ir::_LoweredFunc_>()));
Expand All @@ -163,18 +160,23 @@ void add1(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, struct
const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
float* D = (float*)(cinn_buffer_get_data_handle(_D));
for (int32_t i_outer = 0; (i_outer <= 24); i_outer += 1) {
for (int32_t i_inner = 0; (i_inner <= 3); i_inner += 1) {
for (int32_t j = 0; (j <= 19); j += 1) {
for (int32_t i_outer = 0; i_outer < 25; i_outer += 1) {
for (int32_t i_inner = 0; i_inner < 4; i_inner += 1) {
for (int32_t j = 0; j < 20; j += 1) {
C[((20 * i_inner) + ((80 * i_outer) + j))] = (1 + ((3 * A[((20 * i_inner) + ((80 * i_outer) + j))]) + B[((20 * i_inner) + ((80 * i_outer) + j))]));
};
};
};
for (int32_t i_outer = 0; (i_outer <= 24); i_outer += 1) {
for (int32_t i_inner = 0; (i_inner <= 3); i_inner += 1) {
for (int32_t j_outer = 0; (j_outer <= 1); j_outer += 1) {
for (int32_t j_inner = 0; (j_inner <= min(15, ((-16 * j_outer) + 19))); j_inner += 1) {
D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 + (4 * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))])) * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]);
for (int32_t i_outer = 0; i_outer < 25; i_outer += 1) {
for (int32_t i_inner = 0; i_inner < 4; i_inner += 1) {
for (int32_t j_outer = 0; j_outer < 1; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < 16; j_inner += 1) {
D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]) + (4 * (C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))])));
};
};
for (int32_t j_outer = 1; j_outer < 2; j_outer += 1) {
for (int32_t j_inner = 0; j_inner < (20 + (-16 * j_outer)); j_inner += 1) {
D[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] = ((2 * C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))]) + (4 * (C[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))] * A[((20 * i_inner) + ((80 * i_outer) + ((16 * j_outer) + j_inner)))])));
};
};
};
Expand Down Expand Up @@ -206,15 +208,14 @@ TEST(CodeGenC, matmul) {

// Code gen
auto funcs = Lower("matmul", {A, B, C_init, C});
ASSERT_EQ(funcs.size(), 1UL);

Target target;
target.arch = Target::Arch ::X86;
target.bits = Target::Bit ::k32;
target.os = Target::OS ::Linux;

Module module("module1", target);
module.Append(funcs.front());
module.Append(funcs);
module.Append(C_buf);

CodeGenC codegen(target);
Expand All @@ -234,10 +235,10 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
float* C_init = (float*)(cinn_buffer_get_data_handle(_C));
for (int32_t i = 0; (i <= 99); i += 1) {
for (int32_t j = 0; (j <= 49); j += 1) {
for (int32_t i = 0; i < 100; i += 1) {
for (int32_t j = 0; j < 50; j += 1) {
C_init[((50 * i) + j)] = 0;
for (int32_t k = 0; (k <= 19); k += 1) {
for (int32_t k = 0; k < 20; k += 1) {
C[((50 * i) + j)] = (C[((50 * i) + j)] + (A[((20 * i) + k)] * B[((50 * k) + j)]));
};
};
Expand Down Expand Up @@ -288,15 +289,14 @@ TEST(CodeGenC, matmul_tile) {

// Code gen
auto funcs = Lower("matmul", {A, B, C_init, C});
ASSERT_EQ(funcs.size(), 1UL);

Target target;
target.arch = Target::Arch ::X86;
target.bits = Target::Bit ::k32;
target.os = Target::OS ::Linux;

Module module("module1", target);
module.Append(funcs.front());
module.Append(funcs);
module.Append(C_buf);

CodeGenC codegen(target);
Expand All @@ -316,13 +316,51 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
float* C_init = (float*)(cinn_buffer_get_data_handle(_C));
for (int32_t i_outer = 0; (i_outer <= 3); i_outer += 1) {
for (int32_t j_outer = 0; (j_outer <= 15); j_outer += 1) {
for (int32_t i_inner = 0; (i_inner <= min(31, ((-32 * i_outer) + 99))); i_inner += 1) {
for (int32_t j_inner = 0; (j_inner <= min(31, ((-32 * j_outer) + 499))); j_inner += 1) {
for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * B[((32 * j_outer) + ((500 * k_inner) + ((2000 * k_outer) + j_inner)))]));
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * B[((32 * j_outer) + ((500 * k_inner) + ((2000 * k_outer) + j_inner)))]));
};
};
};
};
};
};
for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * B[((32 * j_outer) + ((500 * k_inner) + ((2000 * k_outer) + j_inner)))]));
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k_outer = 0; (k_outer <= 49); k_outer += 1) {
for (int32_t k_inner = 0; (k_inner <= 3); k_inner += 1) {
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * B[((32 * j_outer) + ((500 * k_inner) + ((2000 * k_outer) + j_inner)))]));
};
};
Expand Down Expand Up @@ -365,15 +403,14 @@ TEST(CodeGenC, matmul_packed) {

// Code gen
auto funcs = Lower("matmul_with_packing", {A, B, packedB, C});
ASSERT_EQ(funcs.size(), 1UL);

Target target;
target.arch = Target::Arch ::X86;
target.bits = Target::Bit ::k32;
target.os = Target::OS ::Linux;

Module module("module1", target);
module.Append(funcs.front());
module.Append(funcs);
module.Append(C_buf);
module.Append(packedB_buf);

Expand All @@ -396,19 +433,54 @@ void matmul_with_packing(const struct cinn_buffer_t *_A, const struct cinn_buffe
const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
float* PackedB = (float*)(cinn_buffer_get_data_handle(_PackedB));
for (int32_t i = 0; (i <= 14); i += 1) {
for (int32_t j = 0; (j <= 199); j += 1) {
for (int32_t k = 0; (k <= 31); k += 1) {
for (int32_t i = 0; i < 15; i += 1) {
for (int32_t j = 0; j < 200; j += 1) {
for (int32_t k = 0; k < 32; k += 1) {
PackedB[((6400 * i) + ((32 * j) + k))] = B[((32 * i) + ((500 * j) + k))];
};
};
};
for (int32_t i_outer = 0; (i_outer <= 3); i_outer += 1) {
for (int32_t j_outer = 0; (j_outer <= 15); j_outer += 1) {
for (int32_t i_inner = 0; (i_inner <= min(31, ((-32 * i_outer) + 99))); i_inner += 1) {
for (int32_t j_inner = 0; (j_inner <= min(31, ((-32 * j_outer) + 499))); j_inner += 1) {
for (int32_t k_outer = 0; (k_outer <= 49); k_outer += 1) {
for (int32_t k_inner = 0; (k_inner <= 3); k_inner += 1) {
for (int32_t i_outer = 0; i_outer < 3; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * PackedB[((6400 * j_outer) + ((32 * k_inner) + ((128 * k_outer) + j_inner)))]);
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < 32; i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner/32)) + ((6400 * j_outer) + ((32 * k_inner) + (128 * k_outer)))))]);
};
};
};
};
};
};
for (int32_t i_outer = 3; i_outer < 4; i_outer += 1) {
for (int32_t j_outer = 0; j_outer < 15; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < 32; j_inner += 1) {
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * PackedB[((6400 * j_outer) + ((32 * k_inner) + ((128 * k_outer) + j_inner)))]);
};
};
};
};
};
for (int32_t j_outer = 15; j_outer < 16; j_outer += 1) {
for (int32_t i_inner = 0; i_inner < (100 + (-32 * i_outer)); i_inner += 1) {
for (int32_t j_inner = 0; j_inner < (500 + (-32 * j_outer)); j_inner += 1) {
for (int32_t k_outer = 0; k_outer < 50; k_outer += 1) {
for (int32_t k_inner = 0; k_inner < 4; k_inner += 1) {
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner/32)) + ((6400 * j_outer) + ((32 * k_inner) + (128 * k_outer)))))]);
};
};
Expand Down
9 changes: 4 additions & 5 deletions cinn/backends/codegen_c_x86_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ TEST(CodeGenCX86, basic) {
C->stage()->Vectorize(1, 16);
C->stage()->Unroll(1);

auto funcs = Lower("matmul", {A, B, C, D});
CHECK_EQ(funcs.size(), 1UL);
auto func = Lower("matmul", {A, B, C, D});

std::cout << "before optim\n" << funcs.front()->body << std::endl;
std::cout << "before optim\n" << func->body << std::endl;

funcs.front()->body = Optimize(funcs.front()->body);
func->body = Optimize(func->body);

lang::Module module("module1", target);
module.Append(funcs[0]);
module.Append(func);

CodeGenCX86 codegen(target, CodeGenCX86::Feature::AVX512);
codegen.SetInlineBuiltinCodes(false);
Expand Down
17 changes: 0 additions & 17 deletions cinn/common/arithmatic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,23 +267,6 @@ std::tuple<Expr, bool /*positive*/> Solve(Expr lhs, Expr rhs, Var var) {
auto diff_res = ginac::diff(diff, symbol);
CHECK(!diff_res.is_zero());

/*
struct Visitor : public ginac::visitor, public GiNaC::numeric::visitor {
int v = std::numeric_limits<int>::min();
void operator()(GiNaC::ex ex) { ex.accept(*this); }
void visit(const GiNaC::numeric& node) override {
if (node.is_positive()) v = 1;
else v = -1;
}
};
Visitor visitor;
visitor(diff_res);
CHECK_NE(visitor.v, std::numeric_limits<int>::min()) << "the diff result should be a integer";
CHECK_NE(visitor.v, 0) << "the diff result should not be zero";
*/

return std::make_tuple(value, diff_res > 0);
}

Expand Down
4 changes: 1 addition & 3 deletions cinn/ir/buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@ TEST(Buffer, bind_to_multiple_tensors) {

auto funcs = lang::Lower("func1", {A, B});

ASSERT_EQ(funcs.size(), 1UL);

Target target;
target.arch = Target::Arch ::X86;
target.bits = Target::Bit ::k32;
target.os = Target::OS ::Linux;

lang::Module module("module1", target);
module.Append(funcs.front());
module.Append(funcs);
module.Append(buf0);

backends::CodeGenC codegen(target);
Expand Down
3 changes: 3 additions & 0 deletions cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,9 @@ struct Broadcast : public ExprNode<Broadcast> {

Type type() const override;

std::vector<Expr*> expr_fields() override { return {&value}; }
std::vector<const Expr*> expr_fields() const override { return {&value}; }

static const IrNodeTy _node_type_ = IrNodeTy::Broadcast;
};

Expand Down
11 changes: 6 additions & 5 deletions cinn/lang/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "cinn/ir/buffer.h"
#include "cinn/ir/ir_printer.h"
#include "cinn/optim/optimize.h"
#include "cinn/optim/remove_nested_block.h"
#include "cinn/optim/replace_call_with_expr.h"
#include "cinn/optim/tensor_write_tell.h"
Expand Down Expand Up @@ -159,13 +160,12 @@ std::vector<ir::Argument> PrepareArguments(const std::vector<Tensor>& tensors, c
}

//! Lower the stages and get a LoweredFunc.
std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args) {
ir::LoweredFunc Lower(const std::string& name, const std::vector<Tensor>& args) {
// make sure the graph's start-points in the args.

auto stages = poly::GatherStagesInTensors(args);
auto extra_dependencies = poly::ExtractExtraDependencyFromStages(stages);
auto graph = poly::CreateGraph(stages, extra_dependencies);
LOG(INFO) << "Graph:\n" << graph->Visualize();

// Create a dic for stages and tensors.
std::map<std::string, Stage*> stage_dic;
Expand Down Expand Up @@ -218,12 +218,13 @@ std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Te
}

Expr block = ir::Block::Make(exprs);
// call passes
optim::RemoveNestedBlock(&block);

// prepare arguments
std::vector<ir::Argument> arguments = PrepareArguments(args, {block});
return {ir::_LoweredFunc_::Make(name, arguments, block)};

auto func = ir::_LoweredFunc_::Make(name, arguments, block);
auto res = optim::Optimize(func);
return ir::LoweredFunc(res.get());
}

} // namespace lang
Expand Down
2 changes: 1 addition & 1 deletion cinn/lang/lower.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace cinn {
namespace lang {
using ir::Tensor;

std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args);
ir::LoweredFunc Lower(const std::string& name, const std::vector<Tensor>& args);

} // namespace lang
} // namespace cinn
Loading

0 comments on commit 91de356

Please sign in to comment.