diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc index 2873bc7ed8f3b..40d88229391d2 100644 --- a/cinn/backends/codegen_c_test.cc +++ b/cinn/backends/codegen_c_test.cc @@ -328,11 +328,11 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru ASSERT_EQ(Trim(target_out), Trim(out)); } -TEST(CodeGenC, matmul_with_packed) { +TEST(CodeGenC, matmul_packed) { const int M = 100; - const int K = 20; - const int N = 50; - const int bn = 4; + const int K = 200; + const int N = 500; + const int bn = 32; Placeholder A("A", {M, K}); Placeholder B("B", {K, N}); @@ -348,6 +348,13 @@ TEST(CodeGenC, matmul_with_packed) { {M, N}, [&](Expr i, Expr j) { return A(i, k) * packedB(j / bn, k, j % bn); }, "C", k); C->Bind(C_buf); + { + poly::Iterator i_outer, i_inner, j_outer, j_inner, k_outer, k_inner; + std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn, bn); + std::tie(k_outer, k_inner) = C->stage()->Split(poly::Iterator("k"), 4); + C->stage()->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner}); + } + // Code gen auto funcs = Lower("matmul_with_packing", {A, B, packedB, C}); ASSERT_EQ(funcs.size(), 1UL); @@ -380,17 +387,23 @@ void matmul_with_packing(const struct cinn_buffer_t *_A, const struct cinn_buffe const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B)); float* C = (float*)(cinn_buffer_get_data_handle(_C)); float* PackedB = (float*)(cinn_buffer_get_data_handle(_PackedB)); - for (int32_t i = 0; (i <= 11); i += 1){ - for (int32_t j = 0; (j <= 19); j += 1){ - for (int32_t k = 0; (k <= 3); k += 1){ - PackedB[(((i * 20) + (j * 4)) + k)] = B[((j * 50) + ((i * 4) + k))]; + for (int32_t i = 0; (i <= 14); i += 1){ + for (int32_t j = 0; (j <= 199); j += 1){ + for (int32_t k = 0; (k <= 31); k += 1){ + PackedB[((((i * 200) * 32) + (j * 32)) + k)] = B[((j * 500) + ((i * 32) + k))]; }; }; }; - for (int32_t i = 0; (i <= 99); i += 1){ - for (int32_t j = 0; (j <= 49); j += 1){ - for (int32_t k = 0; (k <= 19); k += 1){ - C[((i * 50) + j)] = (A[((i * 20) + k)] * PackedB[((((j / 4) * 20) + (k * 4)) + (j % 4))]); + for (int32_t i_outer = 0; (i_outer <= 3); i_outer += 1){ + for (int32_t j_outer = 0; (j_outer <= 15); j_outer += 1){ + for (int32_t i_inner = 0; (i_inner <= min(31, ((-32 * i_outer) + 99))); i_inner += 1){ + for (int32_t j_inner = 0; (j_inner <= min(31, ((-32 * j_outer) + 499))); j_inner += 1){ + for (int32_t k_outer = 0; (k_outer <= 49); k_outer += 1){ + for (int32_t k_inner = 0; (k_inner <= 3); k_inner += 1){ + C[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] = (A[((((32 * i_outer) + i_inner) * 200) + ((4 * k_outer) + k_inner))] * PackedB[(((((((32 * j_outer) + j_inner) / 32) * 200) * 32) + (((4 * k_outer) + k_inner) * 32)) + (((32 * j_outer) + j_inner) % 32))]); + }; + }; + }; }; }; }; diff --git a/cinn/common/CMakeLists.txt b/cinn/common/CMakeLists.txt index 1bd0190d66cec..768c7f5f02658 100644 --- a/cinn/common/CMakeLists.txt +++ b/cinn/common/CMakeLists.txt @@ -6,7 +6,9 @@ set(srcs object.cc graph_utils.cc context.cc - axis.cc) + axis.cc + ir.cc + ) foreach(cpp ${srcs}) set(core_src diff --git a/cinn/ir/buffer.cc b/cinn/ir/buffer.cc index 32ef38a461408..fb573de1e9688 100644 --- a/cinn/ir/buffer.cc +++ b/cinn/ir/buffer.cc @@ -1,6 +1,7 @@ #include "cinn/ir/buffer.h" #include "cinn/common/common.h" +#include "cinn/common/ir.h" #include "cinn/ir/ir_operators.h" #include "cinn/ir/ir_visitor.h" #include "cinn/runtime/intrinsic.h" @@ -77,31 +78,6 @@ Var _Buffer_::buffer_addr() const { return _Var_::Make(name, thetype); } -/* -Expr Buffer::LoadExpr(const std::vector &indice) const { - NOT_IMPLEMENTED - auto *node = operator->(); - return Load::Make(Expr(*this), AbsOffset(indice)); -} - -Expr Buffer::StoreExpr(const std::vector &indice, Expr value) const { - auto *node = operator->(); - return Store::Make(Expr(*this), value, AbsOffset(indice)); -} -*/ - -Expr Buffer::AbsOffset(const std::vector &indice) const { - auto *node = operator->(); - CHECK(!node->shape.empty()); - CHECK_EQ(node->shape.size(), indice.size()) << "shape and indice not match"; - Expr res = indice.front() * node->shape[1]; - for (int i = 1; i < node->shape.size() - 1; i++) { - res = res + indice[i] * node->shape[i + 1]; - } - if (node->shape.size() > 1) res = res + indice.back(); - return res; -} - Expr Buffer::DestroyExpr() const { auto *node = operator->(); return ir::Call::Make( diff --git a/cinn/ir/buffer.h b/cinn/ir/buffer.h index a482e8629c21d..7a87870be6c5a 100644 --- a/cinn/ir/buffer.h +++ b/cinn/ir/buffer.h @@ -50,10 +50,6 @@ class Buffer : public IrNodeRef { const _Buffer_* operator->() const; _Buffer_* operator->(); - - protected: - //! Get a 1-dimension offset given multi-dimension indices. - Expr AbsOffset(const std::vector& indice) const; }; class _Buffer_ : public ExprNode<_Buffer_> { diff --git a/cinn/lang/tensor.cc b/cinn/lang/tensor.cc index 924a371c3a943..42fe1dc3cfd86 100644 --- a/cinn/lang/tensor.cc +++ b/cinn/lang/tensor.cc @@ -3,8 +3,10 @@ #include #include "cinn/common/common.h" +#include "cinn/common/ir.h" #include "cinn/ir/buffer.h" #include "cinn/ir/ir_operators.h" +#include "cinn/ir/ir_printer.h" #include "cinn/ir/ir_visitor.h" #include "cinn/ir/operation.h" #include "cinn/poly/stage.h" @@ -12,26 +14,6 @@ namespace cinn { namespace ir { -namespace detail { - -Expr ExpandTo1DIndice(const std::vector &shape, const std::vector &indices) { - CHECK_EQ(shape.size(), indices.size()); - Expr res = indices.front() * shape[1]; - for (int i = 1; i < shape.size() - 1; i++) { - res = res + indices[i] * shape[i + 1]; - } - if (shape.size() > 1) res = res + indices.back(); - return res; -} - -Expr ExpandTo1DIndice(const std::vector &shape, const std::vector &indices) { - std::vector shape_; - for (int v : shape) shape_.push_back(Expr(v)); - return ExpandTo1DIndice(shape, indices); -} - -} // namespace detail - Tensor _Tensor_::Make(const std::string &name, const std::vector &shape, FunctionRef fn) { CHECK(!shape.empty()) << "Tensor shape is set empty"; CHECK(!name.empty()) << "Tensor name is set empty"; @@ -89,7 +71,7 @@ Expr Tensor::operator()(const std::vector &indices) const { } else { CHECK(node->buffer.defined()) << utils::StringFormat("Buffer for [%s] should be defined so that it can be sliced", node->name.c_str()); - return Load::Make(*this, node->AbsOffset(indices)); + return Load::Make(*this, common::ExpandTo1DIndice(node->shape, indices)); } } @@ -240,7 +222,7 @@ Expr _Tensor_::tensor_store_expanded_body() { } } - return ir::Store::Make(Expr(Buffer(this)), final_body, detail::ExpandTo1DIndice(shape, axis_)); + return ir::Store::Make(Expr(Buffer(this)), final_body, common::ExpandTo1DIndice(shape, axis_)); } void _Tensor_::Bind(lang::Buffer &buffer) { @@ -254,17 +236,6 @@ void _Tensor_::Bind(lang::Buffer &buffer) { InitStage(); } -Expr _Tensor_::AbsOffset(const std::vector &indice) const { - CHECK(!shape.empty()); - CHECK_EQ(shape.size(), indice.size()) << "shape and indice not match"; - Expr res = indice.front() * shape[1]; - for (int i = 1; i < shape.size() - 1; i++) { - res = res + indice[i] * shape[i + 1]; - } - if (shape.size() > 1) res = res + indice.back(); - return res; -} - void Tensor::ExpandInlined() { // Collect all the Calls with Tensors // Expand all the uninlined tensor. diff --git a/cinn/lang/tensor.h b/cinn/lang/tensor.h index 1412093cb1408..9ce41614796c1 100644 --- a/cinn/lang/tensor.h +++ b/cinn/lang/tensor.h @@ -25,10 +25,6 @@ namespace detail { constexpr bool LE(int a, int b) { return a <= b; } constexpr bool GE(int a, int b) { return a >= b; } -//! Expand milti-dim indices to 1-dim index. -Expr ExpandTo1DIndice(const std::vector& shape, const std::vector& indices); -Expr ExpandTo1DIndice(const std::vector& shape, const std::vector& indices); - } // namespace detail class _Tensor_; @@ -166,8 +162,6 @@ class _Tensor_ : public ExprNode<_Tensor_> { ~_Tensor_(); - Expr AbsOffset(const std::vector& indice) const; - private: //! Create the polyhedral element for analysis. //! It is based on the shape.