Skip to content

Commit

Permalink
Load/store to relative indices (PaddlePaddle#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Mar 28, 2020
1 parent eb84e94 commit d01cfc8
Show file tree
Hide file tree
Showing 25 changed files with 321 additions and 118 deletions.
25 changes: 18 additions & 7 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,32 +245,43 @@ void CodeGenC::Visit(const ir::Module *op) { NOT_IMPLEMENTED }
void CodeGenC::Visit(const ir::_Var_ *op) { os() << op->name; }

// Emit C code for a Load node. Three cases are distinguished by the index:
//  1. dense stride-1 Ramp  -> vectorized StackVec<...>::Load(buffer, base)
//  2. other vector index   -> gather     StackVec<...>::Load(buffer, index)
//  3. scalar index         -> plain subscript `tensor_name[index]`
void CodeGenC::Visit(const ir::Load *op) {
  Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
  if (dense_strided_ramp.defined()) {  // Loading a continuous Ramp address.
    CHECK(op->type().is_vector());
    PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
    os() << "::"
         << "Load(";
    os() << op->tensor.As<ir::_Tensor_>()->name;
    os() << ",";
    Print(dense_strided_ramp);
    os() << ")";
  } else if (op->index().type().is_vector()) {
    // Gather: the index is a vector but not a contiguous ramp.
    CHECK(op->type().is_vector());
    PrintStackVecType(op->type().ElementOf(), op->index().type().lanes());
    os() << "::Load(";
    os() << op->tensor.As<ir::_Tensor_>()->name;
    os() << ",";
    Print(op->index());
    os() << ")";
  } else {
    // Scalar load: emit `tensor_name[index]`.
    auto *tensor = op->tensor.As<ir::_Tensor_>();
    CHECK(tensor);
    os() << tensor->name << "[";
    Print(op->index());
    os() << "]";
  }
}

// Emit C code for a Store node as a plain subscripted assignment:
//   tensor_name[index] = value;
void CodeGenC::Visit(const ir::Store *op) {
  auto *tensor_node = op->tensor.As<ir::_Tensor_>();
  CHECK(tensor_node);  // the destination must be a concrete tensor node
  os() << tensor_node->name << "[";
  Print(op->index());
  os() << "] = ";
  Print(op->value);
}
// The following nodes have no C-specific lowering yet; fall back to the
// generic textual IR printer.
void CodeGenC::Visit(const ir::Alloc *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::Free *op) { IrPrinter::Visit(op); }
void CodeGenC::Visit(const ir::_Range_ *op) { IrPrinter::Visit(op); }
Expand Down
14 changes: 7 additions & 7 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void add1(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, struct
float* C = (float*)(cinn_buffer_get_data_handle(_C));
for (int32_t i = 0; (i <= 99); i += 1) {
for (int32_t j = 0; (j <= 19); j += 1) {
C[((i * 20) + j)] = (A[((i * 20) + j)] + B[((i * 20) + j)]);
C[((20 * i) + j)] = (A[((20 * i) + j)] + B[((20 * i) + j)]);
};
};
}
Expand Down Expand Up @@ -236,9 +236,9 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
float* C_init = (float*)(cinn_buffer_get_data_handle(_C));
for (int32_t i = 0; (i <= 99); i += 1) {
for (int32_t j = 0; (j <= 49); j += 1) {
C_init[((i * 50) + j)] = 0;
C_init[((50 * i) + j)] = 0;
for (int32_t k = 0; (k <= 19); k += 1) {
C[((i * 50) + j)] = (C[((i * 50) + j)] + (A[((i * 20) + k)] * B[((k * 50) + j)]));
C[((50 * i) + j)] = (C[((50 * i) + j)] + (A[((20 * i) + k)] * B[((50 * k) + j)]));
};
};
};
Expand Down Expand Up @@ -320,10 +320,10 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
for (int32_t j_outer = 0; (j_outer <= 15); j_outer += 1) {
for (int32_t i_inner = 0; (i_inner <= min(31, ((-32 * i_outer) + 99))); i_inner += 1) {
for (int32_t j_inner = 0; (j_inner <= min(31, ((-32 * j_outer) + 499))); j_inner += 1) {
C_init[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] = 0;
C_init[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = 0;
for (int32_t k_outer = 0; (k_outer <= 49); k_outer += 1) {
for (int32_t k_inner = 0; (k_inner <= 3); k_inner += 1) {
C[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] = (C[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] + (A[((((32 * i_outer) + i_inner) * 200) + ((4 * k_outer) + k_inner))] * B[((((4 * k_outer) + k_inner) * 500) + ((32 * j_outer) + j_inner))]));
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] + (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * B[((32 * j_outer) + ((500 * k_inner) + ((2000 * k_outer) + j_inner)))]));
};
};
};
Expand Down Expand Up @@ -399,7 +399,7 @@ void matmul_with_packing(const struct cinn_buffer_t *_A, const struct cinn_buffe
for (int32_t i = 0; (i <= 14); i += 1) {
for (int32_t j = 0; (j <= 199); j += 1) {
for (int32_t k = 0; (k <= 31); k += 1) {
PackedB[((((i * 200) * 32) + (j * 32)) + k)] = B[((j * 500) + ((i * 32) + k))];
PackedB[((6400 * i) + ((32 * j) + k))] = B[((32 * i) + ((500 * j) + k))];
};
};
};
Expand All @@ -409,7 +409,7 @@ void matmul_with_packing(const struct cinn_buffer_t *_A, const struct cinn_buffe
for (int32_t j_inner = 0; (j_inner <= min(31, ((-32 * j_outer) + 499))); j_inner += 1) {
for (int32_t k_outer = 0; (k_outer <= 49); k_outer += 1) {
for (int32_t k_inner = 0; (k_inner <= 3); k_inner += 1) {
C[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] = (A[((((32 * i_outer) + i_inner) * 200) + ((4 * k_outer) + k_inner))] * PackedB[(((((((32 * j_outer) + j_inner) / 32) * 200) * 32) + (((4 * k_outer) + k_inner) * 32)) + (((32 * j_outer) + j_inner) % 32))]);
C[((500 * i_inner) + ((16000 * i_outer) + ((32 * j_outer) + j_inner)))] = (A[((200 * i_inner) + ((6400 * i_outer) + ((4 * k_outer) + k_inner)))] * PackedB[((j_inner % 32) + ((6400 * (j_inner/32)) + ((6400 * j_outer) + ((32 * k_inner) + (128 * k_outer)))))]);
};
};
};
Expand Down
2 changes: 1 addition & 1 deletion cinn/backends/codegen_c_x86.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ void CodeGenCX86::Visit(const ir::Mul *op) { VisitBinaryOp(op, op->a(), op->b(),
// x86 codegen for division: delegate to the generic binary-op emitter with
// the "div" intrinsic name.
void CodeGenCX86::Visit(const ir::Div *op) { VisitBinaryOp(op, op->a(), op->b(), "div"); }

void CodeGenCX86::Visit(const ir::Load *op) {
Expr dense_strided_ramp = detail::StridedRampBase(op->index, 1);
Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1);
if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address.
CHECK(op->type().is_vector());

Expand Down
5 changes: 3 additions & 2 deletions cinn/backends/codegen_c_x86.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ class CodeGenCX86 : public CodeGenC {
// Print the absolute address of the element `op` accesses, in the form
// `tensor_name + offset`. For a Ramp index only the ramp's base is printed:
// the vectorized load/store intrinsic consumes a contiguous span starting
// there, so the stride/lanes are implied by the vector type.
void PrintAbsAddr(const Op *op) {
  os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";

  // Hold the index expression in a local so `.template As<>` binds to an
  // lvalue rather than a temporary.
  auto index = op->index();
  auto *ramp_n = index.template As<ir::Ramp>();
  if (ramp_n) {
    CHECK(!ramp_n->base.template As<ir::Ramp>()) << "base of a Ramp node should not be Ramp type";
    Print(ramp_n->base);
  } else {
    Print(op->index());
  }
}

Expand Down
7 changes: 7 additions & 0 deletions cinn/common/arithmatic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ bool IsPureMath(Expr expr) {
std::set<IrNodeTy> valid_node_tys({
IrNodeTy ::_Var_,
IrNodeTy ::IntImm,
IrNodeTy ::Sum,
IrNodeTy ::Product,
IrNodeTy ::FracOp,
IrNodeTy ::Power,
IrNodeTy ::FloatImm,
IrNodeTy ::Add,
IrNodeTy ::Sub,
Expand All @@ -226,6 +230,9 @@ bool IsPureMath(Expr expr) {
});

auto complex_nodes = ir::CollectIRNodes(expr, [&](const Expr* n) { return !valid_node_tys.count(n->node_type()); });
for (auto& node : complex_nodes) {
VLOG(3) << "Found " << node->node_type() << " " << Expr(node);
}
return complex_nodes.empty();
}

Expand Down
6 changes: 0 additions & 6 deletions cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,6 @@ Expr CasSimplifyMutator::operator()(Expr u) {
if (u.As<FracOp>()) {
u = SimplifyFracOp(u);
auto tmp = FurtherSimplifyFracWithInterval(u, var_intervals);
LOG(INFO) << "futher simplify result " << tmp;
if (!tmp.same_as(u)) return operator()(tmp);
return u;
}
Expand Down Expand Up @@ -1125,21 +1124,16 @@ bool IsMonotonical(Expr u, Var v) {
// Should be called after SimplifyFracOp. If y is integer and $y\in \[0, 3\]$, then y/4=0
Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval(
Expr expr, const std::unordered_map<std::string, CasInterval>& var_intervals) {
LOG(INFO) << "futher simplify frac " << expr << " " << expr.node_type();
auto* node = expr.As<FracOp>();
if (!node) return expr;
auto a = CasSimplify(node->a(), var_intervals);
auto b = CasSimplify(node->b(), var_intervals);

LOG(INFO) << "a " << a << " " << a.node_type();
LOG(INFO) << "b " << b << " " << b.node_type();

auto* ai = a.As<IntImm>();
auto* bi = b.As<IntImm>();
auto* av = a.As<_Var_>();
auto* bv = b.As<_Var_>();
auto* ap = a.As<Product>();
LOG(INFO) << "var_intervals.size " << var_intervals.size();
// case: y / 4, y\in[0,3]
if (bi) {
if (av) {
Expand Down
4 changes: 2 additions & 2 deletions cinn/common/cas_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ TEST(CAS, ConvertCinnToCAS) {

body = detail::ConvertCinnToCAS(body);
body = CasSimplify(body);
EXPECT_EQ(GetStreamCnt(body), "(1 + A[((i * 10) + j)] + (2 * B[((i * 10) + j)]))");
EXPECT_EQ(GetStreamCnt(body), "(1 + A[i, j] + (2 * B[i, j]))");
body = detail::ConvertCasToCinn(body);
EXPECT_EQ(GetStreamCnt(body), "(1 + (A[((i * 10) + j)] + (2 * B[((i * 10) + j)])))");
EXPECT_EQ(GetStreamCnt(body), "(1 + (A[i, j] + (2 * B[i, j])))");
}

TEST(CAS, FracOp) {
Expand Down
107 changes: 103 additions & 4 deletions cinn/common/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,121 @@

#include "cinn/ir/ir_mutator.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/ir_printer.h"

namespace cinn {
namespace common {

namespace {

// Multiply a Ramp index by a scalar (or Broadcast-of-scalar) factor:
//   ramp(base, stride, lanes) * s -> ramp(base, stride * s, lanes)
// NOTE(review): `base` is deliberately left unscaled here — this matches the
// original code; confirm callers always multiply whole index terms so the
// base-offset contribution is accounted for elsewhere.
Expr RampRelatedMul(ir::Ramp *ramp, Expr other) {
  CHECK_EQ(other.type().ElementOf(), Int(32));
  CHECK_EQ(ramp->base.type(), Int(32));
  CHECK_EQ(ramp->stride.type(), Int(32));
  auto *other_broadcast = other.As<ir::Broadcast>();
  if (other_broadcast) {
    // Unwrap a broadcast so the ramp stride is scaled by the scalar value.
    CHECK_EQ(ramp->lanes, other_broadcast->lanes);
    other = other_broadcast->value;
  }
  return ir::Ramp::Make(ramp->base, ramp->stride * other, ramp->lanes);
}

// Multiply a Broadcast by a scalar:
//   broadcast(v, lanes) * s -> broadcast(v * s, lanes)
Expr RampRelatedMul(ir::Broadcast *broadcast, Expr other) {
  CHECK_EQ(other.type().lanes(), 1);
  return ir::Broadcast::Make(broadcast->value * other, broadcast->lanes);
}
// ramp * ramp is not supported yet.
Expr RampRelatedMul(ir::Ramp *ramp, ir::Ramp *other) {
  NOT_IMPLEMENTED
  return Expr();
}
// Add a scalar (or Broadcast-of-scalar) offset to a Ramp index:
//   ramp(base, stride, lanes) + s -> ramp(base + s, stride, lanes)
Expr RampRelatedAdd(ir::Ramp *ramp, Expr other) {
  CHECK_EQ(other.type().ElementOf(), Int(32));

  auto *other_broadcast = other.As<ir::Broadcast>();
  if (other_broadcast) {
    CHECK_EQ(ramp->lanes, other_broadcast->lanes);
    other = other_broadcast->value;
  }
  return ir::Ramp::Make(ramp->base + other, ramp->stride, ramp->lanes);
}
// Add a scalar to a Broadcast:
//   broadcast(v, lanes) + s -> broadcast(v + s, lanes)
Expr RampRelatedAdd(ir::Broadcast *broadcast, Expr other) {
  CHECK_EQ(other.type().lanes(), 1);
  return ir::Broadcast::Make(broadcast->value + other, broadcast->lanes);
}
// ramp + ramp is not supported yet.
Expr RampRelatedAdd(ir::Ramp *ramp, ir::Ramp *other) {
  NOT_IMPLEMENTED
  return Expr();
}

// Dispatch Add over the Ramp/Broadcast/scalar combinations of both operands,
// falling back to plain `a + b` when neither side is a vector.
Expr RampRelatedAdd(Expr a, Expr b) {
  auto *a_ramp = a.As<ir::Ramp>();
  auto *b_ramp = b.As<ir::Ramp>();
  auto *a_broadcast = a.As<ir::Broadcast>();
  auto *b_broadcast = b.As<ir::Broadcast>();
  if (a_ramp && !b_ramp && (b->type().lanes() == 1 || b_broadcast)) {
    return RampRelatedAdd(a_ramp, b);
  } else if (!a_ramp && b_ramp && (a->type().lanes() == 1 || a_broadcast)) {
    return RampRelatedAdd(b_ramp, a);
  } else if (!a_ramp && !b_ramp && !a->type().is_vector() && !b->type().is_vector()) {
    return a + b;
  } else if (a_ramp && b_ramp) {  // a_ramp && b_ramp
    return RampRelatedAdd(a_ramp, b_ramp);
  } else if (a_broadcast && !b_broadcast) {
    return RampRelatedAdd(a_broadcast, b);
  } else if (!a_broadcast && b_broadcast) {
    return RampRelatedAdd(b_broadcast, a);
  } else if (a_broadcast && b_broadcast) {
    CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes);
    return ir::Broadcast::Make(a_broadcast->value + b_broadcast->value, a_broadcast->lanes);
  } else {
    NOT_IMPLEMENTED
  }
}

// Dispatch Mul over the Ramp/Broadcast/scalar combinations of both operands,
// falling back to plain `a * b` when neither side is a vector.
Expr RampRelatedMul(Expr a, Expr b) {
  auto *a_ramp = a.As<ir::Ramp>();
  auto *b_ramp = b.As<ir::Ramp>();
  auto *a_broadcast = a.As<ir::Broadcast>();
  auto *b_broadcast = b.As<ir::Broadcast>();
  if (a_ramp && !b_ramp && (!b->type().is_vector() || b_broadcast)) {
    return RampRelatedMul(a_ramp, b);
  } else if (!a_ramp && b_ramp && (!a->type().is_vector() || a_broadcast)) {
    // BUG FIX: the condition read `a->type().is_vector()`, the un-negated
    // mirror of the branch above (and of RampRelatedAdd's `lanes() == 1`
    // test). That rejected the scalar * ramp case this branch exists for.
    return RampRelatedMul(b_ramp, a);
  } else if (!a_ramp && !b_ramp && !a->type().is_vector() && !b->type().is_vector()) {
    return a * b;
  } else if (a_ramp && b_ramp) {  // a_ramp && b_ramp
    return RampRelatedMul(a_ramp, b_ramp);
  } else if (a_broadcast && !b_broadcast) {
    return RampRelatedMul(a_broadcast, b);
  } else if (!a_broadcast && b_broadcast) {
    return RampRelatedMul(b_broadcast, a);
  } else if (a_broadcast && b_broadcast) {
    CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes);
    return ir::Broadcast::Make(a_broadcast->value * b_broadcast->value, a_broadcast->lanes);
  } else {
    NOT_IMPLEMENTED
  }
}

}  // namespace

Expr ExpandTo1DIndice(const std::vector<Expr> &shape, const std::vector<Expr> &indices) {
CHECK_EQ(shape.size(), indices.size());
Expr res;
for (int i = 0; i < shape.size(); i++) {
CHECK_EQ(shape[i].type(), Int(32));
Expr indice_prod = indices[i];
for (int j = i + 1; j < shape.size(); j++) {
indice_prod = indice_prod * shape[j];
indice_prod = RampRelatedMul(indice_prod, shape[j]);
}

if (res.defined())
res = res + indice_prod;
else
if (res.defined()) {
res = RampRelatedAdd(res, indice_prod);
} else
res = indice_prod;
}

Expand Down
Loading

0 comments on commit d01cfc8

Please sign in to comment.