diff --git a/cinn/backends/_x86_builtin_source.cc b/cinn/backends/_x86_builtin_source.cc
index cada4afde8697..6a2a3744d29c9 100644
--- a/cinn/backends/_x86_builtin_source.cc
+++ b/cinn/backends/_x86_builtin_source.cc
@@ -129,13 +129,13 @@ struct ExternalVec {
 
 // AVX256 load
 //@{
-inline __m256 cinn_avx256_load(float* dst) { return _mm256_load_ps(dst); }
-inline __m256d cinn_avx256_load(double* dst) { return _mm256_load_pd(dst); }
+inline __m256 cinn_avx256_load(const float* dst) { return _mm256_load_ps(dst); }
+inline __m256d cinn_avx256_load(const double* dst) { return _mm256_load_pd(dst); }
 //@}
 
 // AVX512 load
 //@{
-inline __m512 cinn_avx512_load(float* dst) { return _mm512_load_ps(dst); }
-inline __m512d cinn_avx512_load(double* dst) { return _mm512_load_pd(dst); }
+inline __m512 cinn_avx512_load(const float* dst) { return _mm512_load_ps(dst); }
+inline __m512d cinn_avx512_load(const double* dst) { return _mm512_load_pd(dst); }
 //@}
 
 // FP32x8 * FP32x8
@@ -313,6 +313,22 @@ inline __m512 cinn_avx512_set1(float value) { return _mm512_set1_ps(value); }
 inline __m512d cinn_avx512_set1(double value) { return _mm512_set1_pd(value); }
 // @}
 
+//! store
+// @{
+inline void cinn_avx512_store(float* dst, const __m512& x) { _mm512_store_ps(dst, x); }
+inline void cinn_avx512_store(double* dst, const __m512d& x) { _mm512_store_pd(dst, x); }
+inline void cinn_avx256_store(float* dst, const __m256& x) { _mm256_store_ps(dst, x); }
+inline void cinn_avx256_store(double* dst, const __m256d& x) { _mm256_store_pd(dst, x); }
+// @}
+
+//! add
+// @{
+inline __m256 cinn_avx256_add(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); }
+inline __m256d cinn_avx256_add(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); }
+inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { return _mm512_add_ps(a, b); }
+inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); }
+// @}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 /// )END Predefined utilities in CINN
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/cinn/backends/codegen_c.cc b/cinn/backends/codegen_c.cc
index 8737a2e3cb458..45fbf5e14be23 100644
--- a/cinn/backends/codegen_c.cc
+++ b/cinn/backends/codegen_c.cc
@@ -13,21 +13,21 @@ using namespace utils;
 
 void CodeGenC::Compile(const lang::Module &module, const Outputs &outputs) {
   if (!outputs.c_header_name.empty()) {
-    LOG(WARNING) << "Output C source to file " << outputs.c_header_name;
     auto source = Compile(module, OutputKind::CHeader);
     std::ofstream file(outputs.c_header_name);
     CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
     file << source;
     file.close();
+    LOG(WARNING) << "Output C header to file " << outputs.c_header_name;
   }
 
   if (!outputs.c_source_name.empty()) {
-    LOG(WARNING) << "Output C source to file " << outputs.c_source_name;
     auto source = Compile(module, OutputKind::CImpl);
     std::ofstream file(outputs.c_source_name);
     CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name;
     file << source;
     file.close();
+    LOG(WARNING) << "Output C source to file " << outputs.c_source_name;
   }
 }
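
The new store/add overloads complete the load/set1 family above, so the X86 codegen can compose whole vector statements out of overloaded calls. A minimal hand-written sketch of the pattern this enables (hypothetical buffers a, b, c of n floats, assumed 32-byte aligned with n a multiple of 8):

    // c = a + b, eight float lanes per iteration, using only the wrappers above.
    for (int i = 0; i < n; i += 8) {
      cinn_avx256_store(&c[i], cinn_avx256_add(cinn_avx256_load(&a[i]), cinn_avx256_load(&b[i])));
    }
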
diff --git a/cinn/backends/codegen_c_x86.cc b/cinn/backends/codegen_c_x86.cc
index 23d98560a446b..34a6b00073775 100644
--- a/cinn/backends/codegen_c_x86.cc
+++ b/cinn/backends/codegen_c_x86.cc
@@ -9,6 +9,8 @@ void CodeGenCX86::Visit(const ir::Mul *op) { VisitBinaryOp(op, op->a, op->b, "mul"); }
 void CodeGenCX86::Visit(const ir::Div *op) { VisitBinaryOp(op, op->a, op->b, "div"); }
 
 void CodeGenCX86::Visit(const ir::Load *op) {
+  LOG(INFO) << "visit load argument";
+
   Expr dense_strided_ramp = detail::StridedRampBase(op->index, 1);
   if (dense_strided_ramp.defined()) {  // Loading a continuous Ramp address.
     CHECK(op->type().is_vector());
@@ -42,27 +44,23 @@ void CodeGenCX86::Visit(const ir::Store *op) {
   int bits = op->type().bits() * op->type().lanes();
   if (SupportsAVX512()) {
     CHECK_EQ(bits, 512);
-    os() << "cinn_avx512_store(" << op->tensor.As<ir::_Tensor_>()->name << ", " << op->value << ")";
+    os() << "cinn_avx512_store(";
+    PrintAbsAddr(op);
+    os() << ", ";
+    Print(op->value);
+    os() << ")";
   } else if (SupportsAVX256()) {
     CHECK_EQ(bits, 256);
-    os() << "cinn_avx256_store(" << op->tensor.As<ir::_Tensor_>()->name << ", " << op->value << ")";
+    os() << "cinn_avx256_store(";
+    PrintAbsAddr(op);
+    os() << ", ";
+    Print(op->value);
+    os() << ")";
   } else {
     CodeGenC::Visit(op);
   }
 }
 
-void CodeGenCX86::PrintAbsAddr(const ir::Load *op) {
-  os() << op->tensor.As<ir::_Tensor_>()->name << " + ";
-
-  auto *ramp_n = op->index.As<ir::Ramp>();
-  if (ramp_n) {
-    CHECK(!ramp_n->base.As<ir::Ramp>()) << "base of a Ramp node should not be Ramp type";
-    Print(ramp_n->base);
-  } else {
-    Print(op->index);
-  }
-}
-
 void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
   int bits = op->type().bits() * op->type().lanes();
   auto *broadcast_n = op->As<ir::Broadcast>();
diff --git a/cinn/backends/codegen_c_x86.h b/cinn/backends/codegen_c_x86.h
index 6c0649cd04326..6e497eb01c888 100644
--- a/cinn/backends/codegen_c_x86.h
+++ b/cinn/backends/codegen_c_x86.h
@@ -61,7 +61,20 @@ class CodeGenCX86 : public CodeGenC {
   void PrintVecInputArgument(const Expr *op);
   //! The output argument, such as the destination for Load.
   void PrintVecOutputArgument(const Expr *op);
-  void PrintAbsAddr(const ir::Load *op);
+
+  template <typename Op>
+  void PrintAbsAddr(const Op *op) {
+    os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";
+
+    auto *ramp_n = op->index.template As<ir::Ramp>();
+    if (ramp_n) {
+      CHECK(!ramp_n->base.template As<ir::Ramp>()) << "base of a Ramp node should not be Ramp type";
+      Print(ramp_n->base);
+    } else {
+      Print(op->index);
+    }
+  }
+
   template <typename Op>
   void VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string &op_repr);
 };
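
Templating PrintAbsAddr over the node type lets the Store path reuse the address logic Load already had: for a dense Ramp index, only the ramp base participates in the pointer arithmetic. Illustratively (hand-written, not captured from a real run), a 512-bit store to tensor C with index Ramp(base, 1, 16) now emits its destination as the tensor name plus the base:

    cinn_avx512_store(C + base, /* vectorized value expression */);  // was: cinn_avx512_store(C, ...)
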
diff --git a/cinn/cinn.h b/cinn/cinn.h
index 3b7d55bf3be13..d95586b00e2a6 100644
--- a/cinn/cinn.h
+++ b/cinn/cinn.h
@@ -1,5 +1,6 @@
 #pragma once
 #include "cinn/backends/codegen_c.h"
+#include "cinn/backends/codegen_c_x86.h"
 #include "cinn/common/common.h"
 #include "cinn/lang/builtin.h"
 #include "cinn/lang/compute.h"
@@ -10,6 +11,7 @@
 namespace cinn {
 
 using backends::CodeGenC;
+using backends::CodeGenCX86;
 using backends::Outputs;
 using ir::Var;
 using lang::Buffer;
diff --git a/cinn/common/graph_utils.h b/cinn/common/graph_utils.h
index 8f4b654e9968a..220445fbbf1f6 100644
--- a/cinn/common/graph_utils.h
+++ b/cinn/common/graph_utils.h
@@ -13,6 +13,7 @@
 #include
 #include
+#include <set>
 
 #include "cinn/common/object.h"
 #include "cinn/common/shared.h"
@@ -43,6 +44,23 @@ class GraphEdge : public Object {
   GraphNode* sink_{};
 };
 
+}  // namespace common
+}  // namespace cinn
+
+namespace std {
+
+template <>
+struct hash<cinn::common::Shared<cinn::common::GraphEdge>> {
+  size_t operator()(const cinn::common::Shared<cinn::common::GraphEdge>& key) {
+    return reinterpret_cast<size_t>(key->source()) ^ reinterpret_cast<size_t>(key->sink());
+  }
+};
+
+}  // namespace std
+
+namespace cinn {
+namespace common {
+
 /**
  * @brief The base class of all node of graph.
  * This is used to normalize and share the graph operations.
@@ -55,17 +73,34 @@ class GraphNode : public Object {
   //! Links from this to other.
   template <typename EdgeT>
   std::tuple<EdgeT*, EdgeT*> LinkTo(GraphNode* other) {
+    EdgeT *a, *b;
+    CHECK(other);
     CHECK_NE(other, this) << "cannot link to itself";
-    other->inlinks_.push_back(make_shared<GraphEdge>(other, this));
-    outlinks_.push_back(make_shared<GraphEdge>(this, other));
-    return std::make_tuple(static_cast<EdgeT*>(outlinks_.back().get()),
-                           static_cast<EdgeT*>(other->inlinks().back().get()));
+    auto source_edge = make_shared<GraphEdge>(this, other);
+    auto sink_edge   = make_shared<GraphEdge>(this, other);
+
+    outlinks_.insert(source_edge);
+    other->inlinks_.insert(sink_edge);
+
+    for (auto& item : outlinks_) {
+      if (item->sink()->id() == other->id()) {
+        a = static_cast<EdgeT*>(item.get());
+        break;
+      }
+    }
+    for (auto& item : other->inlinks_) {
+      if (item->sink()->id() == other->id()) {
+        b = static_cast<EdgeT*>(item.get());
+        break;
+      }
+    }
+    return std::make_tuple(a, b);
   }
 
   //! Get the input links of the node.
-  virtual std::list<Shared<GraphEdge>> inlinks() const { return inlinks_; }
+  virtual std::set<Shared<GraphEdge>> inlinks() const { return inlinks_; }
 
   //! Get the output links of the node.
-  virtual std::list<Shared<GraphEdge>> outlinks() const { return outlinks_; }
+  virtual std::set<Shared<GraphEdge>> outlinks() const { return outlinks_; }
 
   //! Get a derived pointer.
   template <typename Derived>
   Derived* As() {
@@ -90,10 +125,10 @@ class GraphNode : public Object {
  protected:
   //! The input links of the node.
   //! \note We record the raw pointer rather than the shared pointer to avoid cycle reference.
-  std::list<Shared<GraphEdge>> inlinks_;
+  std::set<Shared<GraphEdge>> inlinks_;
   //! The output links of the node.
   //! \note We record the raw pointer rather than the shared pointer to avoid cycle reference.
-  std::list<Shared<GraphEdge>> outlinks_;
+  std::set<Shared<GraphEdge>> outlinks_;
 
   mutable int visited_time_{};
 };
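
Because std::set has no back(), LinkTo now inserts both edges and then re-finds them to hand back typed pointers. A minimal usage sketch (a_node and b_node assumed to be GraphNode subclasses with ids already assigned):

    GraphEdge *out_edge = nullptr, *in_edge = nullptr;
    std::tie(out_edge, in_edge) = a_node->LinkTo<GraphEdge>(b_node);  // link a_node -> b_node
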
diff --git a/cinn/lang/lower.cc b/cinn/lang/lower.cc
index 16f2dbd7cbd4d..28ea499cfee36 100644
--- a/cinn/lang/lower.cc
+++ b/cinn/lang/lower.cc
@@ -238,8 +238,9 @@ std::vector<ir::Argument> PrepareArguments(const std::vector<Tensor>& tensors, c
 
 std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args) {
   // make sure the graph's start-points in the args.
-  auto stages = poly::GatherStagesInTensors(args);
-  auto graph  = poly::CreateGraph(stages);
+  auto stages             = poly::GatherStagesInTensors(args);
+  auto extra_dependencies = poly::ExtractExtraDependencyFromStages(stages);
+  auto graph              = poly::CreateGraph(stages, extra_dependencies);
 
   LOG(INFO) << "Graph:\n" << graph->Visualize();
 
   // Create a dic for stages and tensors.
diff --git a/cinn/lang/tensor.cc b/cinn/lang/tensor.cc
index fe05320496514..bcb2f780b233a 100644
--- a/cinn/lang/tensor.cc
+++ b/cinn/lang/tensor.cc
@@ -235,15 +235,15 @@ Expr _Tensor_::tensor_store_expanded_body() {
 }
 
 void _Tensor_::Bind(lang::Buffer &buffer) {
+  // Extract the tensors that have been bound to this buffer.
+  buffer_depended_tensor_names_ = buffer.buffer()->binded_tensor_names();
+
   buffer.buffer()->BindTo(this);
   CHECK(!buffer->binded_tensor_names().empty());
   this->buffer = buffer.buffer();
   CHECK(this->buffer.defined());
   CHECK(!inlined());
 
-  // Extract the tensors thouse has binded to this buffer.
-  buffer_depended_tensor_names_ = this->buffer->binded_tensor_names();
-
   // Reset stage to nullptr to tell others this tensor should be inlined.
   InitStage();
 }
diff --git a/cinn/optim/optimize.cc b/cinn/optim/optimize.cc
index 432c346b3203f..e8d3b0413e01a 100644
--- a/cinn/optim/optimize.cc
+++ b/cinn/optim/optimize.cc
@@ -2,6 +2,7 @@
 
 #include "cinn/optim/ir_copy.h"
 #include "cinn/optim/ir_simplify.h"
+#include "cinn/optim/vectorize_loops.h"
 
 namespace cinn {
 namespace optim {
@@ -9,6 +10,7 @@ namespace optim {
 Expr Optimize(Expr e) {
   auto copied = IRCopy(e);
   Simplify(&copied);
+  VectorizeLoops(&copied, Target());
   return copied;
 }
diff --git a/cinn/optim/vectorize_loops.cc b/cinn/optim/vectorize_loops.cc
index 3a0e5cc3a6108..79076d19e1f22 100644
--- a/cinn/optim/vectorize_loops.cc
+++ b/cinn/optim/vectorize_loops.cc
@@ -276,7 +276,10 @@ struct VectorizeLoops_ : public IRMutator {
     CHECK_GT(extent, 0) << "Loop over " << Expr(forloop->loop_var) << " has extent " << forloop->extent
                         << ". Can only vectorize loops over a constant extent > 1";
+    VLOG(2) << "Vectorizing " << forloop->loop_var << " extent " << extent;
+    VLOG(2) << "body:\n" << node->body;
     Vectorizer(forloop->loop_var, extent).Visit(&node->body);
+    VLOG(2) << "after vectorize body:\n" << node->body;
 
     // Remove the forloop.
     *expr = node->body;
diff --git a/cinn/poly/ast_gen.cc b/cinn/poly/ast_gen.cc
index 3a5c28f6c7666..65bc4183a42b4 100644
--- a/cinn/poly/ast_gen.cc
+++ b/cinn/poly/ast_gen.cc
@@ -352,6 +352,9 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) {
     case isl_ast_op_eq:
       *expr = ir::EQ::Make(ops[0], ops[1]);
       break;
+    case isl_ast_op_pdiv_q:
+      *expr = ir::Div::Make(ops[0], ops[1]);
+      break;
     case isl_ast_op_call: {
       ir::Expr caller_expr = ops.front();
       // TODO(Superjomn) make it an string
diff --git a/cinn/poly/poly_scheduler.cc b/cinn/poly/poly_scheduler.cc
index 839913cc7b8af..0ae7e1d4795ae 100644
--- a/cinn/poly/poly_scheduler.cc
+++ b/cinn/poly/poly_scheduler.cc
@@ -235,7 +235,10 @@ std::unique_ptr<Schedule> PolyScheduler::BuildSchedule() {
 
 PolyScheduler::PolyScheduler(const std::vector<Stage*>& stages) {
   CHECK_GT(stages.size(), 0) << "No stage is provided";
-  dfg_ = CreateGraph(stages);
+  // collect extra links
+  auto extra_links = ExtractExtraDependencyFromStages(stages);
+
+  dfg_ = CreateGraph(stages, extra_links);
 
   for (auto* stage : stages) {
     AddStage(*stage);
diff --git a/cinn/poly/stage.cc b/cinn/poly/stage.cc
index 84764d3ecd33d..235e65e0c588a 100644
--- a/cinn/poly/stage.cc
+++ b/cinn/poly/stage.cc
@@ -42,6 +42,12 @@ Stage::Stage(const isl::set &domain, Expr expr) : domain_(domain), expr_(expr) {
   InitTransform();
 }
 
+std::tuple<Iterator, Iterator> Stage::Split(int level, int factor, SplitRestStrategy strategy) {
+  auto dim_names = GetDimNames(transform_, isl_dim_out);
+  auto axis_name = dim_names.at(level);
+  return Split(axis_name, factor, strategy);
+}
+
 std::tuple<Iterator, Iterator> Stage::Split(const Iterator &level, int factor, SplitRestStrategy strategy) {
   int offset = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level.id.c_str());
   CHECK_GE(offset, 0) << "iterator " << level << " not in " << domain_;
@@ -215,5 +221,19 @@ std::string Stage::ith_dim_name(int level) {
   return dims[level];
 }
 
+Iterator Stage::ith_iterator(int level) { return Iterator(ith_dim_name(level)); }
+
+std::vector<std::pair<std::string, std::string>> ExtractExtraDependencyFromStages(const std::vector<Stage *> &stages) {
+  std::vector<std::pair<std::string, std::string>> extra_links;
+  for (auto &stage : stages) {
+    for (auto &tensor_name : stage->extra_depend_stages()) {
+      LOG(INFO) << "extra link " << tensor_name << " -> " << stage->id();
+      extra_links.emplace_back(tensor_name, stage->id());
+    }
+  }
+
+  return extra_links;
+}
+
 }  // namespace poly
 }  // namespace cinn
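
ExtractExtraDependencyFromStages flattens each stage's extra_depend_stages() entries into (dependee, depender) name pairs, which CreateGraph can then add as edges beyond the ordinary tensor read/write links. A sketch of the intended flow, mirroring the lower.cc change above (names illustrative):

    auto stages = poly::GatherStagesInTensors(args);
    // If stage "C" lists "C_init" in extra_depend_stages(), extra holds
    // {("C_init", "C")} and the dependency graph gains that edge.
    auto extra = poly::ExtractExtraDependencyFromStages(stages);
    auto graph = poly::CreateGraph(stages, extra);
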
diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h
index 5a04c5ad65c9d..eefc702672e65 100644
--- a/cinn/poly/stage.h
+++ b/cinn/poly/stage.h
@@ -53,6 +53,8 @@ class Stage : public Object {
   Split(const Iterator& level, int factor, SplitRestStrategy strategy = SplitRestStrategy::kAuto);
   std::tuple<Iterator, Iterator>  //
   Split(const std::string& level, int factor, SplitRestStrategy strategy = SplitRestStrategy::kAuto);
+  std::tuple<Iterator, Iterator>  //
+  Split(int level, int factor, SplitRestStrategy strategy = SplitRestStrategy::kAuto);
 
   /**
    * Reorder the iterators.
@@ -108,6 +110,8 @@ class Stage : public Object {
   //! Get the level-th dimensional name.
   std::string ith_dim_name(int level);
+  //! Get the i-th iterator.
+  Iterator ith_iterator(int level);
 
   //! Get the statements.
   std::vector<std::string> input_statements() const;
@@ -143,6 +147,8 @@ class Stage : public Object {
   std::set<std::string> extra_depend_stages_;
 };
 
+std::vector<std::pair<std::string, std::string>> ExtractExtraDependencyFromStages(const std::vector<Stage*>& stages);
+
 struct ComputeAtRelation {
   Shared<Stage> stage;
   int level{-1};
diff --git a/cinn/runtime/cinn_runtime.cc b/cinn/runtime/cinn_runtime.cc
index 14c060f36b21f..9249d6cf566a7 100644
--- a/cinn/runtime/cinn_runtime.cc
+++ b/cinn/runtime/cinn_runtime.cc
@@ -69,7 +69,10 @@ cinn_type_t cinn_float64_t() { return cinn_type_t(cinn_type_float, 64); }
 
 }  // extern "C"
 
-struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, cinn_type_t type, const std::vector<int>& shape) {
+struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device,
+                                          cinn_type_t type,
+                                          const std::vector<int>& shape,
+                                          int align) {
   int32_t dimensions = shape.size();
   cinn_dimension_t* dims = new cinn_dimension_t[dimensions];
   memcpy(dims, shape.data(), shape.size() * sizeof(int));
@@ -93,5 +96,6 @@ struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, cinn_type_t
   x->dims = dims;
   x->dimensions = dimensions;
+  x->align = align;
 
   return x;
 }
diff --git a/cinn/runtime/cinn_runtime.h b/cinn/runtime/cinn_runtime.h
index c9b46603ae597..978a44ee1f51d 100644
--- a/cinn/runtime/cinn_runtime.h
+++ b/cinn/runtime/cinn_runtime.h
@@ -159,6 +159,8 @@ typedef struct cinn_buffer_t {
   //! The actual memory size.
   uint64_t memory_size;
 
+  uint16_t align;
+
 #ifdef __cplusplus
   cinn_buffer_t()
       : device(cinn_unk_device),
@@ -168,9 +170,13 @@ typedef struct cinn_buffer_t {
         type(cinn_type_t()),
         dimensions(0),
         dims(NULL),
-        memory_size(0) {}
+        memory_size(0),
+        align(0) {}
 
-  static struct cinn_buffer_t* new_(cinn_device_kind_t device, cinn_type_t type, const std::vector<int>& shape);
+  static struct cinn_buffer_t* new_(cinn_device_kind_t device,
+                                    cinn_type_t type,
+                                    const std::vector<int>& shape,
+                                    int align = 0);
   static void delete_(struct cinn_buffer_t* x) { delete x; }
 
   ~cinn_buffer_t() {
diff --git a/cinn/runtime/cinn_x86_device_impl.cc b/cinn/runtime/cinn_x86_device_impl.cc
index a863a3fe53ca0..4dee0915ec020 100644
--- a/cinn/runtime/cinn_x86_device_impl.cc
+++ b/cinn/runtime/cinn_x86_device_impl.cc
@@ -1,3 +1,4 @@
+#include <stdlib.h>
 #include "cinn/runtime/cinn_runtime.h"
 
 int cinn_x86_malloc(void* context, cinn_buffer_t* buf) {
@@ -9,7 +10,12 @@ int cinn_x86_malloc(void* context, cinn_buffer_t* buf) {
   if (buf->host_memory) {
     free(buf->host_memory);
   }
-  buf->host_memory = (unsigned char*)malloc(buf->type.bytes() * buf->num_elements());
+  int bytes = buf->type.bytes() * buf->num_elements();
+  if (buf->align == 0) {
+    buf->host_memory = (unsigned char*)malloc(bytes);
+  } else {
+    buf->host_memory = (uint8_t*)aligned_alloc(buf->align, bytes);
+  }
   buf->memory_size = memory_size;
   CINN_LOG("buf.memory size is %ld\n", buf->memory_size);
 }
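
One caveat on the aligned branch: C11 aligned_alloc requires the size to be an integral multiple of the alignment, and buf->type.bytes() * buf->num_elements() need not be. A defensive variant would round the request up (a sketch, not part of the patch):

    size_t bytes   = buf->type.bytes() * buf->num_elements();
    size_t rounded = (bytes + buf->align - 1) / buf->align * buf->align;  // round up to a multiple of align
    buf->host_memory = (uint8_t*)aligned_alloc(buf->align, rounded);
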
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 2cb26dcbb2492..6c5a1150fa36c 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -5,7 +5,9 @@ cc_test(test01_elementwise_add_main SRCS test01_elementwise_add_main.cc DEPS core
         ARGS ${global_test_args}
 )
 cc_test(test01_elementwise_add_case
-        SRCS test01_elementwise_add_case.cc ${CMAKE_BINARY_DIR}/tests/test01_elementwise_add.cc
+        SRCS test01_elementwise_add_case.cc
+        ${CMAKE_BINARY_DIR}/tests/test01_elementwise_add.cc
+        ${CMAKE_BINARY_DIR}/tests/test01_elementwise_add_vectorize.cc
         DEPS core)
 add_dependencies(test01_elementwise_add_case test01_elementwise_add_main)
 
@@ -15,5 +17,6 @@ cc_test(test02_matmul_main SRCS test02_matmul_main.cc DEPS core
 cc_test(test02_matmul_case SRCS test02_matmul_case.cc
         ${CMAKE_BINARY_DIR}/tests/test02_matmul.cc
         ${CMAKE_BINARY_DIR}/tests/test02_matmul_tile.cc
+        ${CMAKE_BINARY_DIR}/tests/test02_matmul_split.cc
         DEPS core)
 add_dependencies(test02_matmul_case test02_matmul_main)
diff --git a/tests/test01_elementwise_add_case.cc b/tests/test01_elementwise_add_case.cc
index 5c1f1f0fb9f2c..24bf9b1a5f70d 100644
--- a/tests/test01_elementwise_add_case.cc
+++ b/tests/test01_elementwise_add_case.cc
@@ -1,14 +1,17 @@
+#include <glog/logging.h>
 #include <gtest/gtest.h>
 
 #include "cinn/runtime/cinn_runtime.h"
 #include "tests/test01_elementwise_add.h"
+#include "tests/test01_elementwise_add_vectorize.h"
 
 TEST(test01, basic) {
-  auto* A = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {100, 20});
-  auto* B = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {100, 20});
-  auto* C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {100, 20});
+  auto* A = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {100, 32}, 32);
+  auto* B = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {100, 32}, 32);
+  auto* C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {100, 32}, 32);
   cinn_buffer_malloc(nullptr, A);
   cinn_buffer_malloc(nullptr, B);
+  cinn_buffer_malloc(nullptr, C);
 
   float* Ad = reinterpret_cast<float*>(A->host_memory);
   float* Bd = reinterpret_cast<float*>(B->host_memory);
@@ -18,12 +21,20 @@ TEST(test01, basic) {
     Bd[i] = i;
   }
 
-  add1(A, B, C);
-
   float* Cd = reinterpret_cast<float*>(C->host_memory);
   ASSERT_EQ(C->num_elements(), A->num_elements());
 
-  for (int i = 0; i < C->num_elements(); i++) {
-    EXPECT_EQ(Ad[i] + Bd[i], Cd[i]);
-  }
+  auto check = [&] {
+    for (int i = 0; i < C->num_elements(); i++) {
+      EXPECT_EQ(Ad[i] + Bd[i], Cd[i]);
+    }
+  };
+
+  LOG(INFO) << "test1 basic";
+  add1(A, B, C);
+  check();
+
+  LOG(INFO) << "test1 vectorize";
+  add1_vectorize(A, B, C);
+  check();
 }
diff --git a/tests/test01_elementwise_add_main.cc b/tests/test01_elementwise_add_main.cc
index 025feaee35b7e..c79fcf262b402 100644
--- a/tests/test01_elementwise_add_main.cc
+++ b/tests/test01_elementwise_add_main.cc
@@ -6,12 +6,12 @@ namespace cinn {
 
 TEST(test01_elementwise_add, basic) {
-  Placeholder<float> A("A", {100, 20});
-  Placeholder<float> B("B", {100, 20});
+  Placeholder<float> A("A", {100, 32});
+  Placeholder<float> B("B", {100, 32});
 
   Buffer C_buf(Float(32));
   auto C = Compute(
-      {100, 20}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
+      {100, 32}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
   C->Bind(C_buf);
 
   Target target;
@@ -25,7 +25,7 @@ TEST(test01_elementwise_add, basic) {
   auto func = Optimize(funcs.front());
   module.Append(ir::LoweredFunc(func.As<ir::_LoweredFunc_>()));
-  module.Append(C_buf);
+  // module.Append(C_buf);
 
   CodeGenC compiler(target);
   Outputs outputs;
@@ -33,4 +33,34 @@ TEST(test01_elementwise_add, basic) {
   compiler.Compile(module, outputs);
 }
 
+TEST(test01_elementwise_add, vectorize) {
+  Placeholder<float> A("A", {100, 32});
+  Placeholder<float> B("B", {100, 32});
+
+  Buffer C_buf(Float(32));
+  auto C = Compute(
+      {100, 32}, [&](Var i, Var j) { return A(i, j) + B(i, j); }, "C");
+  C->Bind(C_buf);
+  C->stage()->Vectorize(1, 8);
+
+  Target target;
+  target.arch = Target::Arch ::X86;
+  target.bits = Target::Bit ::k32;
+  target.os = Target::OS ::Linux;
+  Module module("module2", target);
+
+  auto funcs = Lower("add1_vectorize", {A, B, C});
+  ASSERT_EQ(funcs.size(), 1UL);
+
+  auto func = Optimize(funcs.front());
+  LOG(INFO) << "after optim:\n" << func;
+  module.Append(ir::LoweredFunc(func.As<ir::_LoweredFunc_>()));
+  // module.Append(C_buf);
+
+  CodeGenCX86 compiler(target, CodeGenCX86::Feature::AVX256);
+  Outputs outputs;
+  outputs = outputs.c_header("./test01_elementwise_add_vectorize.h").c_source("./test01_elementwise_add_vectorize.cc");
+  compiler.Compile(module, outputs);
+}
+
 }  // namespace cinn
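
In the new test, C->stage()->Vectorize(1, 8) marks axis 1 (the 32-element j loop) for vectorization with factor 8, so optimization collapses every 8 scalar iterations into one vector statement. Conceptually (illustrative C, not actual compiler output):

    for (int j = 0; j < 32; j++)  // before: scalar loop
      c[j] = a[j] + b[j];

    for (int j = 0; j < 32; j += 8)  // after: one AVX256 operation per 8 lanes
      cinn_avx256_store(&c[j], cinn_avx256_add(cinn_avx256_load(&a[j]), cinn_avx256_load(&b[j])));
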
diff --git a/tests/test02_matmul_case.cc b/tests/test02_matmul_case.cc
index 17f198f939beb..07248d864f342 100644
--- a/tests/test02_matmul_case.cc
+++ b/tests/test02_matmul_case.cc
@@ -1,7 +1,9 @@
+#include <glog/logging.h>
 #include <gtest/gtest.h>
 
 #include "cinn/runtime/cinn_runtime.h"
 #include "tests/test02_matmul.h"
+#include "tests/test02_matmul_split.h"
 #include "tests/test02_matmul_tile.h"
 
 TEST(test02, basic) {
@@ -13,6 +15,7 @@ TEST(test02, basic) {
   auto* B = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {K, N});
   auto* C = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N});
   auto* C1 = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N});
+  auto* C2 = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N});
   auto* C_target = cinn_buffer_t::new_(cinn_device_kind_t::cinn_x86_device, cinn_float32_t(), {M, N});
   cinn_buffer_malloc(nullptr, A);
   cinn_buffer_malloc(nullptr, B);
@@ -24,7 +27,6 @@ TEST(test02, basic) {
   float* Bd = reinterpret_cast<float*>(B->host_memory);
   float* Cd_target = reinterpret_cast<float*>(C_target->host_memory);
   float* Cd = reinterpret_cast<float*>(C->host_memory);
-  float* Cd1 = reinterpret_cast<float*>(C1->host_memory);
 
   for (int i = 0; i < M; i++) {
     for (int k = 0; k < K; k++) {
@@ -42,21 +44,35 @@ TEST(test02, basic) {
   for (int i = 0; i < M; i++) {
     for (int j = 0; j < N; j++) {
       Cd_target[i * N + j] = 0.f;
-      Cd[i * N + j] = 0.f;
+      // Cd[i * N + j] = 0.f;
     }
   }
 
-  matmul(A, B, C);
-  matmul_tile(A, B, C1);
+  auto compare = [&]() {
+    for (int i = 0; i < M; i++) {
+      for (int j = 0; j < N; j++) {
+        EXPECT_NEAR(Cd[i * N + j], Cd_target[i * N + j], 1e-5);
+      }
+    }
+  };
 
   for (int i = 0; i < M; i++) {
     for (int j = 0; j < N; j++) {
       for (int k = 0; k < K; k++) {
         Cd_target[i * N + j] += Ad[i * K + k] * Bd[k * N + j];
       }
-
-      EXPECT_NEAR(Cd[i * N + j], Cd_target[i * N + j], 1e-5);
-      EXPECT_NEAR(Cd1[i * N + j], Cd_target[i * N + j], 1e-5);
     }
   }
+
+  LOG(INFO) << "Testing matmul_basic";
+  matmul(A, B, C);
+  compare();
+
+  LOG(INFO) << "Testing matmul_tile";
+  matmul_tile(A, B, C);
+  compare();
+
+  LOG(INFO) << "Testing matmul_split";
+  matmul_split(A, B, C);
+  compare();
 }
diff --git a/tests/test02_matmul_main.cc b/tests/test02_matmul_main.cc
index ea5c2135b2166..ed67828fe3d27 100644
--- a/tests/test02_matmul_main.cc
+++ b/tests/test02_matmul_main.cc
@@ -5,20 +5,24 @@
 
 namespace cinn {
 
-TEST(test02_matmul, basic) {
-  const int M = 1000;
-  const int N = 400;
-  const int K = 500;
+const int M = 1000;
+const int N = 400;
+const int K = 500;
 
+TEST(test02_matmul, basic) {
   Placeholder<float> A("A", {M, K});
   Placeholder<float> B("B", {K, N});
 
   Var k(K, "k");
+  Buffer C_buf(Float(32));
+  auto C_init = Compute(
+      {M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init");
+  C_init->Bind(C_buf);
   auto C = Compute(
       {M, N}, [&](Var i, Var j) { return Sum(A(i, k) * B(k, j), k); }, "C", k);
-  Buffer C_buf(C->type());
   C->Bind(C_buf);
+  ASSERT_EQ(C->buffer_depended_tensor_names().size(), 1UL);
 
   Target target;
   target.arch = Target::Arch ::X86;
@@ -27,7 +31,7 @@ TEST(test02_matmul, basic) {
   {
     Module module("module1", target);
-    auto funcs = Lower("matmul", {A, B, C});
+    auto funcs = Lower("matmul", {A, B, C, C_init});
     ASSERT_EQ(funcs.size(), 1UL);
 
     auto func = Optimize(funcs.front());
@@ -45,7 +49,7 @@ TEST(test02_matmul, basic) {
     C->stage()->Tile(0, 1, 4, 4);
     Module module("module2", target);
-    auto funcs = Lower("matmul_tile", {A, B, C});
+    auto funcs = Lower("matmul_tile", {A, B, C, C_init});
     ASSERT_EQ(funcs.size(), 1UL);
 
     auto func = Optimize(funcs.front());
@@ -59,4 +63,45 @@ TEST(test02_matmul, basic) {
   }
 }
 
+TEST(matmul, Split) {
+  Placeholder<float> A("A", {M, K});
+  Placeholder<float> B("B", {K, N});
+
+  Var k(K, "k");
+  Buffer C_buf(Float(32));
+
+  auto C_init = Compute(
+      {M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init");
+  C_init->Bind(C_buf);
+  auto C = Compute(
+      {M, N}, [&](Var i, Var j) { return Sum(A(i, k) * B(k, j), k); }, "C", k);
+  C->Bind(C_buf);
+  ASSERT_EQ(C->buffer_depended_tensor_names().size(), 1UL);
+
+  Target target;
+  target.arch = Target::Arch ::X86;
+  target.bits = Target::Bit ::k32;
+  target.os = Target::OS ::Linux;
+
+  poly::Iterator i0, i1;
+  std::tie(i0, i1) = C->stage()->Split(2, 16);
+  std::vector<poly::Iterator> iterators({C->stage()->ith_iterator(1),
+                                         C->stage()->ith_iterator(0),
+                                         C->stage()->ith_iterator(2),
+                                         C->stage()->ith_iterator(3)});
+  C->stage()->Reorder(iterators);
+
+  Module module("module3", target);
+  auto funcs = Lower("matmul_split", {A, B, C, C_init});
+  ASSERT_EQ(funcs.size(), 1UL);
+
+  auto func = Optimize(funcs.front());
+  module.Append(ir::LoweredFunc(func.As<ir::_LoweredFunc_>()));
+
+  CodeGenCX86 compiler(target, CodeGenCX86::Feature::AVX512);
+  Outputs outputs;
+  outputs = outputs.c_header("./test02_matmul_split.h").c_source("./test02_matmul_split.cc");
+  compiler.Compile(module, outputs);
+}
+
 }  // namespace cinn
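
For orientation: Split(2, 16) cuts the reduction axis k (axis 2) into a 16-wide inner axis, and the Reorder call moves axis 1 outermost, so the lowered nest should look roughly like the following (illustrative, with the inner axis the candidate for the AVX512 path):

    for (int j ...)                  // axis 1 moved outermost
      for (int i ...)                // axis 0
        for (int k_outer ...)        // k split by 16
          for (int k_inner = 0; k_inner < 16; k_inner++)
            C[i][j] += A[i][k_outer * 16 + k_inner] * B[k_outer * 16 + k_inner][j];
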