add reduce axis (PaddlePaddle#72)
* refine Var, add reduce axis

* support reduce_sum
Superjomn committed Mar 9, 2020
1 parent a2d1a8c commit c594325
Showing 7 changed files with 49 additions and 7 deletions.
5 changes: 3 additions & 2 deletions cinn/backends/codegen_c_test.cc
@@ -5,6 +5,7 @@
#include <sstream>
#include <tuple>

#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/module.h"
@@ -189,7 +190,7 @@ TEST(CodeGenC, matmul) {
Var k(20, "k");

Tensor C = Compute(
{100, 50}, [&](Var i, Var j) { return A(i, k) * B(k, j); }, "C", k);
{100, 50}, [&](Var i, Var j) { return lang::Sum(A(i, k) * B(k, j), k); }, "C", k);
C->Bind(C_buf);

// Code gen
@@ -223,7 +224,7 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
for (int32_t i = 0; (i <= 99); i += 1){
for (int32_t j = 0; (j <= 49); j += 1){
for (int32_t k = 0; (k <= 19); k += 1){
- C[((i * 50) + j)] = (A[((i * 20) + k)] * B[((k * 50) + j)]);
+ C[((i * 50) + j)] = (C[((i * 50) + j)] + (A[((i * 20) + k)] * B[((k * 50) + j)]));
};
};
};
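Taken together, the test changes above show the new reduction workflow: a Var built with an extent serves as the reduce axis, the body is wrapped in lang::Sum, and the same axis is handed to Compute so the lowered kernel gets the extra loop. A condensed, hedged recap of that pattern (A, B and C_buf are assumed to be declared as in the existing matmul test; this restates only what the diff itself exercises):

    // Reduce axis: a Var with extent 20, iterated as 0 <= k <= 19 in the generated kernel.
    Var k(20, "k");

    // Wrap the reduced expression in lang::Sum and pass the same axis to Compute.
    Tensor C = Compute(
        {100, 50}, [&](Var i, Var j) { return lang::Sum(A(i, k) * B(k, j), k); }, "C", k);
    C->Bind(C_buf);  // bind the computed tensor to its output buffer, as in the test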
12 changes: 12 additions & 0 deletions cinn/ir/node.cc
@@ -37,6 +37,18 @@ std::ostream &operator<<(std::ostream &os, IrNodeTy type) {
return os;
}

+ Expr Zero(const Type &type) {
+ if (type.is_float(32)) return Expr(0.f);
+ if (type.is_float(64)) return Expr(double(0.));
+ if (type.is_bool()) return Expr(false);
+ if (type.is_int(32)) return Expr(int32_t(0));
+ if (type.is_int(64)) return Expr(int64_t(0));
+ if (type.is_uint(32)) return Expr(uint32_t(0));
+ if (type.is_uint(64)) return Expr(uint64_t(0));
+ NOT_IMPLEMENTED
+ return Expr();
+ }
+
Expr::Expr(const Var &var) { *static_cast<IrNodeRef *>(this) = *static_cast<const IrNodeRef *>(&var); }

int32_t Expr::as_int32() const {
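For orientation, a minimal sketch of how the new ir::Zero() helper is meant to be used; Float(32) and Int(64) are the type constructors already used in this file, and body stands for an arbitrary Expr being summed (as lang::Sum does below):

    Expr zero_f32 = ir::Zero(Float(32));  // FloatImm holding 0.f
    Expr zero_i64 = ir::Zero(Int(64));    // IntImm holding int64_t(0)

    // lang::Sum uses the typed zero as the initial value of a sum reduction:
    Expr reduced = ir::Reduce::Make(ir::Reduce::kSum, ir::Zero(body.type()), body);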
6 changes: 6 additions & 0 deletions cinn/ir/node.h
@@ -6,6 +6,7 @@
#include <memory>
#include <string>

#include "cinn/common/common.h"
#include "cinn/common/object.h"
#include "cinn/common/shared.h"
#include "cinn/common/type.h"
@@ -222,7 +223,9 @@ struct Expr : public IrNodeRef {
//! Helper function to construct numeric constants of various types.
// @{
explicit Expr(int32_t x) : IrNodeRef(new IntImm(Int(32), x)) {}
+ explicit Expr(uint32_t x) : IrNodeRef(new IntImm(UInt(32), x)) {}
explicit Expr(int64_t x) : IrNodeRef(new IntImm(Int(64), x)) {}
+ explicit Expr(uint64_t x) : IrNodeRef(new IntImm(UInt(64), x)) {}
explicit Expr(float x) : IrNodeRef(new FloatImm(Float(32), x)) {}
explicit Expr(double x) : IrNodeRef(new FloatImm(Float(64), x)) {}
// @}
@@ -279,6 +282,9 @@ struct BinaryOpNode : public ExprNode<T>, public BinaryArguHolder {
bool is_binary_op() const override { return true; }
};

+ //! Zero in CINN type system.
+ Expr Zero(const Type& type);
+
enum class DeviceAPI {
UNK,
Host,
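The uint32_t and uint64_t constructors appear to be added so that Zero() above can materialize unsigned immediates; a short sketch of what they produce:

    Expr u32 = Expr(uint32_t(0));  // IntImm carrying a UInt(32) value
    Expr u64 = Expr(uint64_t(0));  // IntImm carrying a UInt(64) value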
7 changes: 6 additions & 1 deletion cinn/lang/builtin.cc
@@ -1,5 +1,10 @@
#include "cinn/ir/ir.h"
#include "cinn/lang/buffer.h"

namespace cinn {
- namespace lang {}  // namespace lang
+ namespace lang {
+
+ Expr Sum(Expr body, Var reduce_axis) { return ir::Reduce::Make(ir::Reduce::kSum, ir::Zero(body.type()), body); }
+
+ }  // namespace lang
} // namespace cinn
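Note that Sum only wraps the body in an ir::Reduce node seeded with the typed zero; the reduce_axis argument is not forwarded to Reduce::Make in this commit, so the axis still has to be passed to Compute explicitly (as the matmul test does), and the actual accumulation is produced later in cinn/lang/tensor.cc. Spelled out, with body standing for the user expression:

    // What lang::Sum(body, k) builds here:
    Expr sum = ir::Reduce::Make(ir::Reduce::kSum,       // reduction kind
                                ir::Zero(body.type()),  // typed initial value
                                body);                  // expression being reduced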
2 changes: 1 addition & 1 deletion cinn/lang/builtin.h
@@ -4,7 +4,7 @@
namespace cinn {
namespace lang {

- Expr ReduceSum(Expr body, Var axis);
+ Expr Sum(Expr body, Var reduce_axis);

} // namespace lang
} // namespace cinn
15 changes: 14 additions & 1 deletion cinn/lang/tensor.cc
@@ -226,8 +226,21 @@ Expr _Tensor_::tensor_store_expanded_body() {
for (int i = 0; i < axis.size(); i++) {
axis_.push_back(Expr(axis[i]));
}
+ Expr final_body = body();
+
+ auto *reduce_node = body().As<ir::Reduce>();
+ if (reduce_node) {
+ final_body = reduce_node->body;
+ switch (reduce_node->reduce_type) {
+ case ir::Reduce::kSum:
+ final_body = Tensor(this)(axis_) + final_body;
+ break;
+ default:
+ NOT_IMPLEMENTED
+ }
+ }

- return ir::Store::Make(Expr(Buffer(this)), body(), detail::ExpandTo1DIndice(shape, axis_));
+ return ir::Store::Make(Expr(Buffer(this)), final_body, detail::ExpandTo1DIndice(shape, axis_));
}

void _Tensor_::Bind(lang::Buffer &buffer) {
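The effect of this store expansion is visible in the regenerated matmul kernel earlier in the commit: for a kSum reduce node, the stored value becomes the tensor's current element plus the reduce body, which lowers to an in-place accumulation roughly of the form below (a sketch matching the expected codegen output; the expected output shows no separate zero-initialization of C):

    for (int32_t k = 0; (k <= 19); k += 1) {
      // accumulate the product into C's current element
      C[((i * 50) + j)] = (C[((i * 50) + j)] + (A[((i * 20) + k)] * B[((k * 50) + j)]));
    }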
9 changes: 7 additions & 2 deletions cinn/lang/tensor.h
@@ -102,11 +102,13 @@ class _Tensor_ : public ExprNode<_Tensor_> {
void* stage_shared{};

public:
- //! Shape of this tensor.
+ //! Shape of this tensor (buffer).
std::vector<Expr> shape;
- //! The domain of each axis, TODO(Superjomn) support ISL domain.
+ //! The domain of each axis.
+ // TODO(Superjomn) support ISL domain.
std::vector<Expr> domain;
//! Tensor axis.
+ // TODO(Superjomn) seems not needed, to be discarded later.
std::vector<Var> axis;
//! The operation that generates Tensor.
FunctionRef operaion;
@@ -152,6 +154,9 @@ class _Tensor_ : public ExprNode<_Tensor_> {

Expr inline_expanded(const std::vector<Expr>& indices);

+ //! Tell whether this tensor contains a reduce axis.
+ bool contains_reduce_axis() const { return shape.size() == domain.size(); }
+
std::vector<Expr*> expr_fields() override;
std::vector<const Expr*> expr_fields() const override;

