From 46e23871b36dd1fd086c9b822efac4137f378a81 Mon Sep 17 00:00:00 2001 From: Xy Ren Date: Mon, 24 Jan 2022 14:14:06 +0800 Subject: [PATCH] [opt] Remove legacy vectorization pass (#4096) --- cpp_examples/aot_save.cpp | 4 +- cpp_examples/autograd.cpp | 6 +- cpp_examples/run_snode.cpp | 6 +- python/taichi/lang/__init__.py | 1 - taichi/backends/cc/codegen_cc.cpp | 3 +- taichi/backends/opengl/codegen_opengl.cpp | 3 +- taichi/codegen/spirv/spirv_codegen.cpp | 3 +- taichi/ir/frontend_ir.cpp | 12 - taichi/ir/frontend_ir.h | 1 - taichi/ir/ir.cpp | 1 - taichi/ir/ir.h | 5 - taichi/ir/ir_builder.cpp | 13 +- taichi/ir/ir_builder.h | 3 - taichi/ir/statements.cpp | 21 +- taichi/ir/statements.h | 9 - taichi/ir/transforms.h | 4 - taichi/ir/type.cpp | 10 +- taichi/ir/type.h | 22 -- taichi/ir/type_factory.cpp | 13 - taichi/ir/type_factory.h | 2 - taichi/program/kernel.cpp | 5 +- taichi/python/export_lang.cpp | 1 - taichi/transforms/compile_to_offloads.cpp | 16 +- taichi/transforms/ir_printer.cpp | 9 +- taichi/transforms/loop_vectorize.cpp | 177 ---------- taichi/transforms/lower_access.cpp | 2 +- taichi/transforms/lower_ast.cpp | 15 +- taichi/transforms/simplify.cpp | 6 +- taichi/transforms/type_check.cpp | 11 +- taichi/transforms/vector_split.cpp | 337 -------------------- tests/cpp/analysis/same_statements_test.cpp | 2 +- 31 files changed, 46 insertions(+), 677 deletions(-) delete mode 100644 taichi/transforms/loop_vectorize.cpp delete mode 100644 taichi/transforms/vector_split.cpp diff --git a/cpp_examples/aot_save.cpp b/cpp_examples/aot_save.cpp index 3c09d2647db99..cc88ebbb53403 100644 --- a/cpp_examples/aot_save.cpp +++ b/cpp_examples/aot_save.cpp @@ -32,7 +32,7 @@ void aot_save() { IRBuilder builder; auto *zero = builder.get_int32(0); auto *n_stmt = builder.get_int32(n); - auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4); + auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *index = builder.get_loop_index(loop); @@ -55,7 +55,7 @@ void aot_save() { */ IRBuilder builder; auto *sum = builder.create_local_var(PrimitiveType::i32); - auto *loop = builder.create_struct_for(pointer, 1, 0, 4); + auto *loop = builder.create_struct_for(pointer, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *index = builder.get_loop_index(loop); diff --git a/cpp_examples/autograd.cpp b/cpp_examples/autograd.cpp index be257416811ab..0b3bc8f43ed95 100644 --- a/cpp_examples/autograd.cpp +++ b/cpp_examples/autograd.cpp @@ -91,7 +91,7 @@ void autograd() { auto *zero = builder.get_int32(0); auto *one = builder.get_int32(1); auto *n_stmt = builder.get_int32(n); - auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4); + auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *i = builder.get_loop_index(loop); @@ -114,7 +114,7 @@ void autograd() { auto get_kernel_cal = [&](bool grad) -> Kernel * { IRBuilder builder; - auto *loop = builder.create_struct_for(a, 1, 0, 4); + auto *loop = builder.create_struct_for(a, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *i = builder.get_loop_index(loop); @@ -133,7 +133,7 @@ void autograd() { { IRBuilder builder; - auto *loop = builder.create_struct_for(a, 1, 0, 4); + auto *loop = builder.create_struct_for(a, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *i = builder.get_loop_index(loop); diff --git a/cpp_examples/run_snode.cpp b/cpp_examples/run_snode.cpp index a5ed8cbd890e1..992f6ae1d79f2 100644 --- a/cpp_examples/run_snode.cpp +++ b/cpp_examples/run_snode.cpp @@ -64,7 +64,7 @@ void run_snode() { IRBuilder builder; auto *zero = builder.get_int32(0); auto *n_stmt = builder.get_int32(n); - auto *loop = builder.create_range_for(zero, n_stmt, 1, 0, 4); + auto *loop = builder.create_range_for(zero, n_stmt, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *index = builder.get_loop_index(loop); @@ -87,7 +87,7 @@ void run_snode() { */ IRBuilder builder; auto *sum = builder.create_local_var(PrimitiveType::i32); - auto *loop = builder.create_struct_for(pointer, 1, 0, 4); + auto *loop = builder.create_struct_for(pointer, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *index = builder.get_loop_index(loop); @@ -110,7 +110,7 @@ void run_snode() { # ext = place.to_numpy() */ IRBuilder builder; - auto *loop = builder.create_struct_for(pointer, 1, 0, 4); + auto *loop = builder.create_struct_for(pointer, 0, 4); { auto _ = builder.get_loop_guard(loop); auto *index = builder.get_loop_index(loop); diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 07bb2ed2b3783..02da8cddb39a7 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -742,7 +742,6 @@ def loop_unique(val, covers=None): parallelize = _ti_core.parallelize serialize = lambda: parallelize(1) -vectorize = _ti_core.vectorize bit_vectorize = _ti_core.bit_vectorize block_dim = _ti_core.block_dim global_thread_idx = _ti_core.insert_thread_idx_expr diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index 9dcc7214b657c..97473554c3d8a 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -51,8 +51,7 @@ class CCTransformer : public IRVisitor { auto ir = kernel_->ir.get(); auto config = kernel_->program->config; config.demote_dense_struct_fors = true; - irpass::compile_to_executable(ir, config, kernel_, - /*vectorize=*/false, kernel_->grad, + irpass::compile_to_executable(ir, config, kernel_, kernel_->grad, /*ad_use_stack=*/true, config.print_ir, /*lower_global_access*/ true); } diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 30d056cb622b9..39e3aa0fa6abf 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -1223,8 +1223,7 @@ void OpenglCodeGen::lower() { auto ir = kernel_->ir.get(); auto &config = kernel_->program->config; config.demote_dense_struct_fors = true; - irpass::compile_to_executable(ir, config, kernel_, - /*vectorize=*/false, kernel_->grad, + irpass::compile_to_executable(ir, config, kernel_, kernel_->grad, /*ad_use_stack=*/false, config.print_ir, /*lower_global_access=*/true, /*make_thread_local=*/config.make_thread_local); diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 1ed32346b8542..a60288c612008 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1804,8 +1804,7 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs, void lower(Kernel *kernel) { auto &config = kernel->program->config; config.demote_dense_struct_fors = true; - irpass::compile_to_executable(kernel->ir.get(), config, kernel, - /*vectorize=*/false, kernel->grad, + irpass::compile_to_executable(kernel->ir.get(), config, kernel, kernel->grad, /*ad_use_stack=*/false, config.print_ir, /*lower_global_access=*/true, /*make_thread_local=*/false); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index ec6bfbc894b91..3f7c9823752b0 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -38,14 +38,12 @@ IRNode *FrontendContext::root() { FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, const Expr &global_var) : global_var(global_var) { - vectorize = dec.vectorize; bit_vectorize = dec.bit_vectorize; num_cpu_threads = dec.num_cpu_threads; strictly_serialized = dec.strictly_serialized; block_dim = dec.block_dim; auto cfg = get_current_program().config; if (cfg.arch == Arch::cuda) { - vectorize = 1; num_cpu_threads = 1; TI_ASSERT(block_dim <= taichi_max_gpu_block_dim); } else { @@ -55,8 +53,6 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } mem_access_opt = dec.mem_access_opt; dec.reset(); - if (vectorize == -1) - vectorize = 1; loop_var_id.resize(loop_var.size()); for (int i = 0; i < (int)loop_var.size(); i++) { @@ -69,13 +65,11 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, const mesh::MeshPtr &mesh, const mesh::MeshElementType &element_type) : mesh_for(true), mesh(mesh.ptr.get()), element_type(element_type) { - vectorize = dec.vectorize; bit_vectorize = dec.bit_vectorize; num_cpu_threads = dec.num_cpu_threads; block_dim = dec.block_dim; auto cfg = get_current_program().config; if (cfg.arch == Arch::cuda) { - vectorize = 1; num_cpu_threads = 1; TI_ASSERT(block_dim <= taichi_max_gpu_block_dim); } else { @@ -85,8 +79,6 @@ FrontendForStmt::FrontendForStmt(const ExprGroup &loop_var, } mem_access_opt = dec.mem_access_opt; dec.reset(); - if (vectorize == -1) - vectorize = 1; loop_var_id.resize(loop_var.size()); for (int i = 0; i < (int)loop_var.size(); i++) { @@ -105,14 +97,12 @@ FrontendForStmt::FrontendForStmt(const Expr &loop_var, const Expr &begin, const Expr &end) : begin(begin), end(end) { - vectorize = dec.vectorize; bit_vectorize = dec.bit_vectorize; num_cpu_threads = dec.num_cpu_threads; strictly_serialized = dec.strictly_serialized; block_dim = dec.block_dim; auto cfg = get_current_program().config; if (cfg.arch == Arch::cuda) { - vectorize = 1; num_cpu_threads = 1; } else { if (num_cpu_threads == 0) @@ -120,8 +110,6 @@ FrontendForStmt::FrontendForStmt(const Expr &loop_var, } mem_access_opt = dec.mem_access_opt; dec.reset(); - if (vectorize == -1) - vectorize = 1; loop_var_id.resize(1); loop_var_id[0] = loop_var.cast()->id; loop_var.expr->ret_type = PrimitiveType::i32; diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 02252c39cd520..63e8e84d30666 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -126,7 +126,6 @@ class FrontendForStmt : public Stmt { Expr global_var; std::unique_ptr body; std::vector loop_var_id; - int vectorize; int bit_vectorize; int num_cpu_threads; bool strictly_serialized; diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 5d02e785cd254..624dd9110de73 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -24,7 +24,6 @@ std::string snode_access_flag_name(SNodeAccessFlag type) { } void DecoratorRecorder::reset() { - vectorize = -1; bit_vectorize = -1; num_cpu_threads = 0; uniform = false; diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index c2cec18c85b35..cfbc88cdd6ecd 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -73,7 +73,6 @@ class MemoryAccessOptions { class DecoratorRecorder { public: - int vectorize; int bit_vectorize; int num_cpu_threads; bool strictly_serialized; @@ -708,10 +707,6 @@ struct LocalAddress { extern DecoratorRecorder dec; -inline void Vectorize(int v) { - dec.vectorize = v; -} - inline void BitVectorize(int v) { dec.bit_vectorize = v; } diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index 759017d106be1..716da1b16cdb8 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -85,34 +85,31 @@ IRBuilder::IfGuard::~IfGuard() { RangeForStmt *IRBuilder::create_range_for(Stmt *begin, Stmt *end, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim, bool strictly_serialized) { return insert(Stmt::make_typed( - begin, end, std::make_unique(), vectorize, bit_vectorize, - num_cpu_threads, block_dim, strictly_serialized)); + begin, end, std::make_unique(), bit_vectorize, num_cpu_threads, + block_dim, strictly_serialized)); } StructForStmt *IRBuilder::create_struct_for(SNode *snode, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim) { return insert(Stmt::make_typed( - snode, std::make_unique(), vectorize, bit_vectorize, - num_cpu_threads, block_dim)); + snode, std::make_unique(), bit_vectorize, num_cpu_threads, + block_dim)); } MeshForStmt *IRBuilder::create_mesh_for(mesh::Mesh *mesh, mesh::MeshElementType element_type, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim) { return insert(Stmt::make_typed( - mesh, element_type, std::make_unique(), vectorize, bit_vectorize, + mesh, element_type, std::make_unique(), bit_vectorize, num_cpu_threads, block_dim)); } diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index b61473c00f599..049b09649c3d0 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -103,19 +103,16 @@ class IRBuilder { // Control flows. RangeForStmt *create_range_for(Stmt *begin, Stmt *end, - int vectorize = -1, int bit_vectorize = -1, int num_cpu_threads = 0, int block_dim = 0, bool strictly_serialized = false); StructForStmt *create_struct_for(SNode *snode, - int vectorize = -1, int bit_vectorize = -1, int num_cpu_threads = 0, int block_dim = 0); MeshForStmt *create_mesh_for(mesh::Mesh *mesh, mesh::MeshElementType element_type, - int vectorize = -1, int bit_vectorize = -1, int num_cpu_threads = 0, int block_dim = 0); diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 7960feda3620e..61d9a6aad5c26 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -236,7 +236,6 @@ std::unique_ptr ConstStmt::copy() { RangeForStmt::RangeForStmt(Stmt *begin, Stmt *end, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim, @@ -245,7 +244,6 @@ RangeForStmt::RangeForStmt(Stmt *begin, : begin(begin), end(end), body(std::move(body)), - vectorize(vectorize), bit_vectorize(bit_vectorize), num_cpu_threads(num_cpu_threads), block_dim(block_dim), @@ -258,21 +256,19 @@ RangeForStmt::RangeForStmt(Stmt *begin, std::unique_ptr RangeForStmt::clone() const { auto new_stmt = std::make_unique( - begin, end, body->clone(), vectorize, bit_vectorize, num_cpu_threads, - block_dim, strictly_serialized); + begin, end, body->clone(), bit_vectorize, num_cpu_threads, block_dim, + strictly_serialized); new_stmt->reversed = reversed; return new_stmt; } StructForStmt::StructForStmt(SNode *snode, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim) : snode(snode), body(std::move(body)), - vectorize(vectorize), bit_vectorize(bit_vectorize), num_cpu_threads(num_cpu_threads), block_dim(block_dim) { @@ -281,9 +277,8 @@ StructForStmt::StructForStmt(SNode *snode, } std::unique_ptr StructForStmt::clone() const { - auto new_stmt = std::make_unique(snode, body->clone(), - vectorize, bit_vectorize, - num_cpu_threads, block_dim); + auto new_stmt = std::make_unique( + snode, body->clone(), bit_vectorize, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; return new_stmt; } @@ -291,14 +286,12 @@ std::unique_ptr StructForStmt::clone() const { MeshForStmt::MeshForStmt(mesh::Mesh *mesh, mesh::MeshElementType element_type, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim) : mesh(mesh), major_from_type(element_type), body(std::move(body)), - vectorize(vectorize), bit_vectorize(bit_vectorize), num_cpu_threads(num_cpu_threads), block_dim(block_dim) { @@ -307,9 +300,9 @@ MeshForStmt::MeshForStmt(mesh::Mesh *mesh, } std::unique_ptr MeshForStmt::clone() const { - auto new_stmt = std::make_unique( - mesh, major_from_type, body->clone(), vectorize, bit_vectorize, - num_cpu_threads, block_dim); + auto new_stmt = + std::make_unique(mesh, major_from_type, body->clone(), + bit_vectorize, num_cpu_threads, block_dim); new_stmt->major_to_types = major_to_types; new_stmt->minor_relation_types = minor_relation_types; new_stmt->mem_access_opt = mem_access_opt; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index bb2d52e0ec9c6..391b47bd8f4d7 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -728,7 +728,6 @@ class RangeForStmt : public Stmt { Stmt *begin, *end; std::unique_ptr body; bool reversed; - int vectorize; int bit_vectorize; int num_cpu_threads; int block_dim; @@ -738,7 +737,6 @@ class RangeForStmt : public Stmt { RangeForStmt(Stmt *begin, Stmt *end, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim, @@ -758,7 +756,6 @@ class RangeForStmt : public Stmt { TI_STMT_DEF_FIELDS(begin, end, reversed, - vectorize, bit_vectorize, num_cpu_threads, block_dim, @@ -777,7 +774,6 @@ class StructForStmt : public Stmt { std::unique_ptr block_initialization; std::unique_ptr block_finalization; std::vector index_offsets; - int vectorize; int bit_vectorize; int num_cpu_threads; int block_dim; @@ -785,7 +781,6 @@ class StructForStmt : public Stmt { StructForStmt(SNode *snode, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim); @@ -798,7 +793,6 @@ class StructForStmt : public Stmt { TI_STMT_DEF_FIELDS(snode, index_offsets, - vectorize, bit_vectorize, num_cpu_threads, block_dim, @@ -813,7 +807,6 @@ class MeshForStmt : public Stmt { public: mesh::Mesh *mesh; std::unique_ptr body; - int vectorize; int bit_vectorize; int num_cpu_threads; int block_dim; @@ -825,7 +818,6 @@ class MeshForStmt : public Stmt { MeshForStmt(mesh::Mesh *mesh, mesh::MeshElementType element_type, std::unique_ptr &&body, - int vectorize, int bit_vectorize, int num_cpu_threads, int block_dim); @@ -837,7 +829,6 @@ class MeshForStmt : public Stmt { std::unique_ptr clone() const override; TI_STMT_DEF_FIELDS(mesh, - vectorize, bit_vectorize, num_cpu_threads, block_dim, diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 58b2e1d2fde97..4ff9b74fc4204 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -53,10 +53,8 @@ void type_check(IRNode *root, const CompileConfig &config); bool inlining(IRNode *root, const CompileConfig &config, const InliningPass::Args &args); -void loop_vectorize(IRNode *root, const CompileConfig &config); void bit_loop_vectorize(IRNode *root); void slp_vectorize(IRNode *root); -void vector_split(IRNode *root, int max_width, bool serial_schedule); void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt); bool check_out_of_bound(IRNode *root, const CompileConfig &config, @@ -149,7 +147,6 @@ void compile_to_offloads(IRNode *ir, const CompileConfig &config, Kernel *kernel, bool verbose, - bool vectorize, bool grad, bool ad_use_stack, bool start_from_ast); @@ -167,7 +164,6 @@ void offload_to_executable(IRNode *ir, void compile_to_executable(IRNode *ir, const CompileConfig &config, Kernel *kernel, - bool vectorize, bool grad, bool ad_use_stack, bool verbose, diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 66b59835db4c1..eb220683bbadd 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -78,10 +78,6 @@ std::string PointerType::to_string() const { } } -std::string VectorType::to_string() const { - return fmt::format("[{} x {}]", num_elements_, element_->to_string()); -} - std::string TensorType::to_string() const { std::string s = "[Tensor ("; for (int i = 0; i < (int)shape_.size(); ++i) { @@ -92,11 +88,7 @@ std::string TensorType::to_string() const { } int Type::vector_width() const { - if (auto vec = cast()) { - return vec->get_num_elements(); - } else { - return 1; - } + return 1; // TODO: CPU vectorization } bool Type::is_primitive(PrimitiveTypeID type) const { diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 0b72af912a15d..b9a3b6cc888cb 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -154,28 +154,6 @@ class PointerType : public Type { bool is_bit_pointer_{false}; }; -class VectorType : public Type { - public: - VectorType(int num_elements, Type *element) - : num_elements_(num_elements), element_(element) { - TI_ASSERT(num_elements_ != 1); - } - - Type *get_element_type() const { - return element_; - } - - int get_num_elements() const { - return num_elements_; - } - - std::string to_string() const override; - - private: - int num_elements_{0}; - Type *element_{nullptr}; -}; - class TensorType : public Type { public: TensorType(std::vector shape, Type *element) diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 4c0724c19381e..2df31d64fc070 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -22,14 +22,6 @@ Type *TypeFactory::get_primitive_type(PrimitiveTypeID id) { return primitive_types_[id].get(); } -Type *TypeFactory::get_vector_type(int num_elements, Type *element) { - auto key = std::make_pair(num_elements, element); - if (vector_types_.find(key) == vector_types_.end()) { - vector_types_[key] = std::make_unique(num_elements, element); - } - return vector_types_[key].get(); -} - Type *TypeFactory::get_tensor_type(std::vector shape, Type *element) { auto encode = [](const std::vector &shape) -> std::string { std::string s; @@ -174,11 +166,6 @@ class TypePromotionMapping { TI_WARN("promoted_type got a pointer input."); } - if (d->is()) { - d = d->as()->get_element_type(); - TI_WARN("promoted_type got a vector input."); - } - if (d->is()) { d = d->as()->get_element_type(); TI_WARN("promoted_type got a tensor input."); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index e414df1ef7302..ddcd498b15c47 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -17,8 +17,6 @@ class TypeFactory { PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); - Type *get_vector_type(int num_elements, Type *element); - Type *get_tensor_type(std::vector shape, Type *element); Type *get_pointer_type(Type *element, bool is_bit_pointer = false); diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index a87af3af66fa1..8e6e9e66cbdc2 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -84,7 +84,7 @@ void Kernel::lower(bool to_executable) { if (to_executable) { irpass::compile_to_executable( - ir.get(), config, this, /*vectorize*/ arch_is_cpu(arch), grad, + ir.get(), config, this, grad, /*ad_use_stack=*/true, verbose, /*lower_global_access=*/to_executable, /*make_thread_local=*/config.make_thread_local, /*make_block_local=*/ @@ -92,8 +92,7 @@ void Kernel::lower(bool to_executable) { config.make_block_local, /*start_from_ast=*/ir_is_ast_); } else { - irpass::compile_to_offloads(ir.get(), config, this, verbose, - /*vectorize=*/arch_is_cpu(arch), grad, + irpass::compile_to_offloads(ir.get(), config, this, verbose, grad, /*ad_use_stack=*/true, /*start_from_ast=*/ir_is_ast_); } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index ede052ad005b1..752ac617ede2c 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -1046,7 +1046,6 @@ void export_lang(py::module &m) { }); // Schedules m.def("parallelize", Parallelize); - m.def("vectorize", Vectorize); m.def("bit_vectorize", BitVectorize); m.def("block_dim", BlockDim); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index c84447147235a..9c6c3935f7975 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -33,7 +33,6 @@ void compile_to_offloads(IRNode *ir, const CompileConfig &config, Kernel *kernel, bool verbose, - bool vectorize, bool grad, bool ad_use_stack, bool start_from_ast) { @@ -69,16 +68,6 @@ void compile_to_offloads(IRNode *ir, return; } - if (vectorize) { - irpass::loop_vectorize(ir, config); - print("Loop Vectorized"); - irpass::analysis::verify(ir); - - irpass::vector_split(ir, config.max_vector_width, config.serial_schedule); - print("Loop Split"); - irpass::analysis::verify(ir); - } - // TODO: strictly enforce bit vectorization for x86 cpu and CUDA now // create a separate CompileConfig flag for the new pass if (arch_is_cpu(config.arch) || config.arch == Arch::cuda) { @@ -272,7 +261,6 @@ void offload_to_executable(IRNode *ir, void compile_to_executable(IRNode *ir, const CompileConfig &config, Kernel *kernel, - bool vectorize, bool grad, bool ad_use_stack, bool verbose, @@ -282,8 +270,8 @@ void compile_to_executable(IRNode *ir, bool start_from_ast) { TI_AUTO_PROF; - compile_to_offloads(ir, config, kernel, verbose, vectorize, grad, - ad_use_stack, start_from_ast); + compile_to_offloads(ir, config, kernel, verbose, grad, ad_use_stack, + start_from_ast); offload_to_executable(ir, config, kernel, verbose, /*determine_ad_stack_size=*/grad && ad_use_stack, diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 0375a4bdd31c8..261cff3e00a94 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -336,18 +336,17 @@ class IRPrinter : public IRVisitor { } void visit(RangeForStmt *for_stmt) override { - print("{} : {}for in range({}, {}) (vectorize {}) (bit_vectorize {}) {}{{", + print("{} : {}for in range({}, {}) (bit_vectorize {}) {}{{", for_stmt->name(), for_stmt->reversed ? "reversed " : "", - for_stmt->begin->name(), for_stmt->end->name(), for_stmt->vectorize, + for_stmt->begin->name(), for_stmt->end->name(), for_stmt->bit_vectorize, block_dim_info(for_stmt->block_dim)); for_stmt->body->accept(this); print("}}"); } void visit(StructForStmt *for_stmt) override { - print("{} : struct for in {} (vectorize {}) (bit_vectorize {}) {}{}{{", - for_stmt->name(), for_stmt->snode->get_node_type_name_hinted(), - for_stmt->vectorize, for_stmt->bit_vectorize, + print("{} : struct for in {} (bit_vectorize {}) {}{}{{", for_stmt->name(), + for_stmt->snode->get_node_type_name_hinted(), for_stmt->bit_vectorize, scratch_pad_info(for_stmt->mem_access_opt), block_dim_info(for_stmt->block_dim)); for_stmt->body->accept(this); diff --git a/taichi/transforms/loop_vectorize.cpp b/taichi/transforms/loop_vectorize.cpp deleted file mode 100644 index dba316e41c7da..0000000000000 --- a/taichi/transforms/loop_vectorize.cpp +++ /dev/null @@ -1,177 +0,0 @@ -// The loop vectorizer - -#include "taichi/program/program.h" -#include "taichi/ir/ir.h" -#include "taichi/ir/type_factory.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" - -TLANG_NAMESPACE_BEGIN - -// Lower Expr tree to a bunch of binary/unary(binary/unary) statements -// Goal: eliminate Expression, and mutable local variables. Make AST SSA. -class LoopVectorize : public IRVisitor { - public: - int vectorize; - Stmt *loop_var; // an alloca... - const CompileConfig &config; - - explicit LoopVectorize(const CompileConfig &config) : config(config) { - allow_undefined_visitor = true; - invoke_default_visitor = true; - loop_var = nullptr; - vectorize = 1; - } - - static void widen_type(DataType &type, int width) { - if (width != 1) { - type = Program::get_type_factory().get_vector_type(width, type); - } - } - - void visit(Stmt *stmt) override { - widen_type(stmt->ret_type, vectorize); - } - - void visit(ConstStmt *stmt) override { - stmt->val.repeat(vectorize); - widen_type(stmt->ret_type, vectorize); - } - - void visit(Block *stmt_list) override { - std::vector statements; - for (auto &stmt : stmt_list->statements) { - statements.push_back(stmt.get()); - } - for (auto stmt : statements) { - stmt->accept(this); - } - } - - void visit(GlobalPtrStmt *ptr) override { - ptr->snodes.repeat(vectorize); - widen_type(ptr->ret_type, vectorize); - } - - void visit(AllocaStmt *alloca) override { - widen_type(alloca->ret_type, vectorize); - } - - void visit(SNodeOpStmt *stmt) override { - if (vectorize == 1) - return; - // TI_NOT_IMPLEMENTED; - /* - stmt->snodes.repeat(vectorize); - stmt->ret_type.width *= vectorize; - */ - } - - void visit(ElementShuffleStmt *stmt) override { - if (vectorize == 1) - return; - int original_width = stmt->width(); - widen_type(stmt->ret_type, vectorize); - stmt->elements.repeat(vectorize); - // TODO: this can be buggy - int stride = stmt->elements[original_width - 1].index + 1; - if (stmt->elements[0].stmt->width() != 1) { - for (int i = 0; i < vectorize; i++) { - for (int j = 0; j < original_width; j++) { - stmt->elements[i * original_width + j].index += i * stride; - } - } - } - } - - void visit(LocalLoadStmt *stmt) override { - if (vectorize == 1) - return; - int original_width = stmt->width(); - widen_type(stmt->ret_type, vectorize); - stmt->src.repeat(vectorize); - // TODO: this can be buggy - int stride = stmt->src[original_width - 1].offset + 1; - if (stmt->src[0].var->width() != 1) { - for (int i = 0; i < vectorize; i++) { - for (int j = 0; j < original_width; j++) { - stmt->src[i * original_width + j].offset += i * stride; - } - } - } - if (loop_var && stmt->same_source() && stmt->src[0].var == loop_var) { - // insert_before_me - LaneAttribute const_offsets; - const_offsets.resize(vectorize * original_width); - for (int i = 0; i < vectorize * original_width; i++) { - const_offsets[i] = TypedConstant(i / original_width); - } - auto offsets = std::make_unique(const_offsets); - auto add_op = std::make_unique(BinaryOpType::add, stmt, - offsets.get()); - irpass::type_check(add_op.get(), config); - auto offsets_p = offsets.get(); - stmt->replace_usages_with(add_op.get()); - stmt->insert_after_me(std::move(offsets)); - offsets_p->insert_after_me(std::move(add_op)); - } - } - - void visit(IfStmt *if_stmt) override { - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(RangeForStmt *for_stmt) override { - auto old_vectorize = for_stmt->vectorize; - if (for_stmt->vectorize != 1) - vectorize = for_stmt->vectorize; - // TODO: RangeForStmt::loop_var is deprecated - // loop_var = for_stmt->loop_var; - for_stmt->body->accept(this); - // loop_var = nullptr; - vectorize = old_vectorize; - } - - void visit(StructForStmt *for_stmt) override { - // TODO: StructForStmt::loop_var is deprecated - return; - /*if (for_stmt->loop_vars.empty()) - return; - auto old_vectorize = for_stmt->vectorize; - if (for_stmt->vectorize != 1) - vectorize = for_stmt->vectorize; - loop_var = for_stmt->loop_vars.back(); - for_stmt->body->accept(this); - loop_var = nullptr; - vectorize = old_vectorize;*/ - } - - void visit(MeshForStmt *for_stmt) override { - return; - } - - void visit(WhileStmt *stmt) override { - stmt->body->accept(this); - } - - static void run(IRNode *node, const CompileConfig &config) { - LoopVectorize inst(config); - node->accept(&inst); - } -}; - -namespace irpass { - -void loop_vectorize(IRNode *root, const CompileConfig &config) { - TI_AUTO_PROF; - return LoopVectorize::run(root, config); -} - -} // namespace irpass - -TLANG_NAMESPACE_END diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 651ec5d734499..73d437cc7b605 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -269,7 +269,7 @@ Stmt *PtrLowererImpl::handle_snode_at_level(int level, if (!diff.linear_related()) { on_loop_tree = false; } else if (j == (int)indices_.size() - 1) { - if (!(0 <= diff.low && diff.high <= current_struct_for->vectorize)) { + if (!(0 <= diff.low && diff.high <= 1)) { // TODO: Vectorize on_loop_tree = false; } } else { diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 7bf9067f04eca..219aa1499de4c 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -207,9 +207,8 @@ class LowerAST : public IRVisitor { // statement if (is_good_range_for) { auto &&new_for = std::make_unique( - begin->stmt, end->stmt, std::move(stmt->body), stmt->vectorize, - stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim, - stmt->strictly_serialized); + begin->stmt, end->stmt, std::move(stmt->body), stmt->bit_vectorize, + stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized); new_for->body->insert(std::make_unique(new_for.get(), 0), 0); new_for->body->local_var_to_stmt[stmt->loop_var_id[0]] = @@ -268,8 +267,7 @@ class LowerAST : public IRVisitor { } else if (stmt->mesh_for) { auto &&new_for = std::make_unique( stmt->mesh, stmt->element_type, std::move(stmt->body), - stmt->vectorize, stmt->bit_vectorize, stmt->num_cpu_threads, - stmt->block_dim); + stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim); new_for->body->insert(std::make_unique(new_for.get(), 0), 0); new_for->body->local_var_to_stmt[stmt->loop_var_id[0]] = @@ -303,7 +301,7 @@ class LowerAST : public IRVisitor { snode = snode->parent; auto &&new_for = std::make_unique( - snode, std::move(stmt->body), stmt->vectorize, stmt->bit_vectorize, + snode, std::move(stmt->body), stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim); new_for->index_offsets = offsets; VecStatement new_statements; @@ -343,9 +341,8 @@ class LowerAST : public IRVisitor { } // TODO: add a note explaining why shape might be empty. auto &&new_for = std::make_unique( - begin, end, std::move(stmt->body), stmt->vectorize, - stmt->bit_vectorize, stmt->num_cpu_threads, stmt->block_dim, - stmt->strictly_serialized, + begin, end, std::move(stmt->body), stmt->bit_vectorize, + stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized, /*range_hint=*/fmt::format("arg {}", tensor->arg_id)); VecStatement new_statements; Stmt *loop_index = diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 98f5e0b82a057..c582ba0cec699 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -205,10 +205,8 @@ class BasicBlockSimplify : public IRVisitor { ~((1LL << (stmt->bit_begin)) - 1); auto load_addr = load.get(); modifier.insert_before(stmt, std::move(load)); - if (current_struct_for->vectorize == 1) - offset = diff.low; - if (stmt->bit_begin == 0 && - current_struct_for->vectorize == bound) { + offset = diff.low; // TODO: Vectorization + if (stmt->bit_begin == 0 && bound == 1) { // TODO: Vectorization // TODO: take care of cases where vectorization width != z // dimension of the block auto offset_stmt = Stmt::make(stmt, offset); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index fcfb3b497117f..d431bd2cf8552 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -360,10 +360,8 @@ class TypeCheck : public IRVisitor { if (stmt->op_type == TernaryOpType::select) { auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type); TI_ASSERT(stmt->op1->ret_type->is_primitive(PrimitiveTypeID::i32)) - TI_ASSERT(stmt->op1->ret_type->vector_width() == - stmt->op2->ret_type->vector_width()); - TI_ASSERT(stmt->op2->ret_type->vector_width() == - stmt->op3->ret_type->vector_width()); + TI_ASSERT(stmt->op1->width() == stmt->op2->width()); + TI_ASSERT(stmt->op2->width() == stmt->op3->width()); if (ret_type != stmt->op2->ret_type) { auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type); stmt->op2 = cast_stmt; @@ -402,18 +400,17 @@ class TypeCheck : public IRVisitor { } void visit(ArgLoadStmt *stmt) override { - const auto &rt = stmt->ret_type; // TODO: Maybe have a type_inference() pass, which takes in the args/rets // defined by the kernel. After that, type_check() pass will purely do // verification, without modifying any types. - TI_ASSERT(rt->vector_width() == 1); + TI_ASSERT(stmt->width() == 1); stmt->ret_type.set_is_pointer(stmt->is_ptr); } void visit(ReturnStmt *stmt) override { // TODO: Support stmt->ret_id? stmt->ret_type = stmt->values[0]->ret_type; - TI_ASSERT(stmt->ret_type->vector_width() == 1); + TI_ASSERT(stmt->width() == 1); } void visit(ExternalPtrStmt *stmt) override { diff --git a/taichi/transforms/vector_split.cpp b/taichi/transforms/vector_split.cpp deleted file mode 100644 index ea7bcdbc0418a..0000000000000 --- a/taichi/transforms/vector_split.cpp +++ /dev/null @@ -1,337 +0,0 @@ -// Split vectors wider than machine vector width into multiple vectors - -#include "taichi/ir/ir.h" -#include "taichi/ir/statements.h" -#include "taichi/ir/transforms.h" -#include "taichi/ir/visitors.h" -#include "taichi/program/program.h" - -#include - -TLANG_NAMESPACE_BEGIN - -class BasicBlockVectorSplit : public IRVisitor { - public: - Block *block; - std::vector statements; - std::vector> splits; - int max_width; - - int current_split_factor; - std::vector current_split; - bool need_split; - bool serial_schedule; - std::unordered_map> origin2split; - - BasicBlockVectorSplit(Block *block, int max_width, bool serial_schedule) - : block(block), max_width(max_width), serial_schedule(serial_schedule) { - // allow_undefined_visitor = true; - // invoke_default_visitor = false; - run(); - } - - int lane_start(int split) { - return split * max_width; - } - - int lane_end(int split) { - return (split + 1) * max_width; - } - - Stmt *lookup(Stmt *old, int index) { - if (origin2split.find(old) == origin2split.end()) { - TI_WARN("VectorSplitter looking for statement outside current block?"); - return old; - } else { - TI_ASSERT(0 <= index); - TI_ASSERT(index < (int)origin2split[old].size()); - return origin2split[old][index]; - } - } - - void run() { - std::vector statements = std::move(block->statements); - for (int i = 0; i < (int)statements.size(); i++) { - auto stmt = statements[i].get(); - if (stmt->width() > max_width) { - TI_ASSERT(stmt->width() % max_width == 0); - current_split_factor = stmt->width() / max_width; - current_split.resize(current_split_factor); - need_split = true; - stmt->accept(this); - origin2split[stmt] = std::vector(current_split_factor, nullptr); - for (int j = 0; j < current_split_factor; j++) { - current_split[j]->ret_type = - Program::get_type_factory().get_vector_type(max_width, - stmt->element_type()); - origin2split[stmt][j] = current_split[j].get(); - } - splits.push_back(std::move(current_split)); - } else { // recreate a statement anyway since the original one may be - // pointing to unknown statements - current_split_factor = 1; - current_split.resize(current_split_factor); - need_split = false; - stmt->accept(this); - origin2split[stmt] = std::vector(1, nullptr); - current_split[0]->element_type() = stmt->element_type(); - current_split[0]->ret_type = - Program::get_type_factory().get_vector_type(stmt->width(), - stmt->element_type()); - origin2split[stmt][0] = current_split[0].get(); - std::vector split; - split.push_back(std::move(current_split[0])); - splits.push_back(std::move(split)); - } - } - block->statements.clear(); - if (!serial_schedule) { - // finish vectors one by one - for (int i = 0; i < (int)splits.size(); i++) { - for (int j = 0;; j++) { - bool modified = false; - if (j < (int)splits[i].size()) { - block->insert(std::move(splits[i][j])); - modified = true; - } - if (!modified) { - break; - } - } - } - } else { - for (int j = 0;; j++) { - bool modified = false; - for (int i = 0; i < (int)splits.size(); i++) { - if (j < (int)splits[i].size()) { - block->insert(std::move(splits[i][j])); - modified = true; - } - } - if (!modified) { - break; - } - } - } - for (int i = 0; i < (int)block->statements.size(); i++) { - auto stmt_ = block->statements[i].get(); - if (stmt_->is()) { - auto stmt = stmt_->as(); - for (int l = 0; l < stmt->width(); l++) { - auto *old_var = stmt->src[l].var; - if (origin2split.find(old_var) != origin2split.end()) { - auto new_var = - origin2split[old_var][stmt->src[l].offset / max_width]; - stmt->src[l].var = new_var; - stmt->src[l].offset %= max_width; - // TI_WARN("replaced..."); - } - } - } - } - } - - // Visitors: set current_split[0...current_split_factor] - - void visit(GlobalPtrStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - std::vector indices; - for (int j = 0; j < (int)stmt->indices.size(); j++) { - indices.push_back(lookup(stmt->indices[j], i)); - } - current_split[i] = Stmt::make( - stmt->snodes.slice(lane_start(i), - need_split ? lane_end(i) : stmt->width()), - indices); - } - } - - void visit(ConstStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(stmt->val.slice( - lane_start(i), need_split ? lane_end(i) : stmt->width())); - } - } - - void visit(AllocaStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) - current_split[i] = Stmt::make( - need_split ? max_width : stmt->width(), stmt->element_type()); - } - - void visit(ElementShuffleStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - LaneAttribute ptr; - int new_width = need_split ? max_width : stmt->width(); - ptr.resize(new_width); - for (int j = 0; j < new_width; j++) { - VectorElement addr(stmt->elements[lane_start(i) + j]); - if (origin2split.find(addr.stmt) == origin2split.end()) { - ptr[j] = addr; - } else { - ptr[j].stmt = lookup(addr.stmt, addr.index / max_width); - ptr[j].index = addr.index % max_width; - } - } - current_split[i] = Stmt::make(ptr); - } - } - - void visit(LocalLoadStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - LaneAttribute ptr; - int new_width = need_split ? max_width : stmt->width(); - ptr.reserve(new_width); - for (int j = 0; j < new_width; j++) { - LocalAddress addr(stmt->src[lane_start(i) + j]); - if (origin2split.find(addr.var) == origin2split.end()) { - ptr.push_back(addr); - } else { - ptr.push_back(LocalAddress(lookup(addr.var, addr.offset / max_width), - addr.offset % max_width)); - } - } - current_split[i] = Stmt::make(ptr); - } - } - - void visit(LocalStoreStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->dest, i), - lookup(stmt->val, i)); - } - } - - void visit(GlobalLoadStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->src, i)); - } - } - - void visit(GlobalStoreStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->dest, i), - lookup(stmt->val, i)); - } - } - - void visit(UnaryOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = - Stmt::make(stmt->op_type, lookup(stmt->operand, i)); - current_split[i]->as()->cast_type = - stmt->as()->cast_type; - } - } - - void visit(BinaryOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make( - stmt->op_type, lookup(stmt->lhs, i), lookup(stmt->rhs, i)); - } - } - - void visit(TernaryOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = - Stmt::make(stmt->op_type, lookup(stmt->op1, i), - lookup(stmt->op2, i), lookup(stmt->op3, i)); - } - } - - void visit(AtomicOpStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make( - stmt->op_type, lookup(stmt->dest, i), lookup(stmt->val, i)); - } - } - - void visit(PrintStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - std::vector new_contents; - std::transform(stmt->contents.begin(), stmt->contents.end(), - std::back_inserter(new_contents), - [=](auto const &x) -> PrintStmt::EntryType { - if (std::holds_alternative(x)) { - return lookup(std::get(x), i); - } else { - return x; - } - }); - current_split[i] = Stmt::make(new_contents); - } - } - - void visit(RandStmt *stmt) override { - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(stmt->element_type()); - } - } - - void visit(WhileControlStmt *stmt) override { - TI_ASSERT(need_split == false); - for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->mask, i), - lookup(stmt->cond, i)); - } - } -}; - -// Goal: eliminate vectors that are longer than physical vector width (e.g. 8 -// on AVX2) -class VectorSplit : public IRVisitor { - public: - int max_width; - bool serial_schedule; - - VectorSplit(IRNode *node, int max_width, bool serial_schedule) - : max_width(max_width), serial_schedule(serial_schedule) { - allow_undefined_visitor = true; - invoke_default_visitor = true; - node->accept(this); - } - - void visit(Block *block) override { - if (!block->has_container_statements()) { - bool all_within_width = true; - for (auto &stmt : block->statements) { - if (stmt->width() > max_width) { - all_within_width = false; - } - } - if (!all_within_width) - BasicBlockVectorSplit(block, max_width, serial_schedule); - } else { - for (auto &stmt : block->statements) { - stmt->accept(this); - } - } - } - - void visit(IfStmt *if_stmt) override { - if (if_stmt->true_statements) - if_stmt->true_statements->accept(this); - if (if_stmt->false_statements) { - if_stmt->false_statements->accept(this); - } - } - - void visit(RangeForStmt *for_stmt) override { - for_stmt->body->accept(this); - } - - void visit(WhileStmt *stmt) override { - stmt->body->accept(this); - } -}; - -namespace irpass { - -void vector_split(IRNode *root, int max_width, bool serial_schedule) { - TI_AUTO_PROF; - VectorSplit(root, max_width, serial_schedule); -} - -} // namespace irpass - -TLANG_NAMESPACE_END diff --git a/tests/cpp/analysis/same_statements_test.cpp b/tests/cpp/analysis/same_statements_test.cpp index 05391e7d783c3..3c20bd2cbfcf3 100644 --- a/tests/cpp/analysis/same_statements_test.cpp +++ b/tests/cpp/analysis/same_statements_test.cpp @@ -158,7 +158,7 @@ TEST(SameStatements, TestSameLoopIndex) { auto range_for = block ->push_back(zero, four, std::make_unique(), 1, 1, - 1, 1, false) + 1, false) ->as(); auto loop_index_a = range_for->body->push_back(range_for, 0); auto loop_index_b = range_for->body->push_back(range_for, 0);