From ec742aaeeb6aefd26de0aa389953fca3cfee11a7 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Sun, 6 Jun 2021 23:27:38 +0800 Subject: [PATCH 1/3] [ir] Make lower_scalar_pointer testable --- taichi/analysis/arithmetic_interpretor.cpp | 33 +++ taichi/analysis/arithmetic_interpretor.h | 37 ++- taichi/ir/ir_builder.cpp | 2 +- taichi/ir/ir_builder.h | 4 +- taichi/transforms/lower_access.cpp | 227 +++++++++--------- taichi/transforms/scalar_pointer_lowerer.cpp | 82 +++++++ taichi/transforms/scalar_pointer_lowerer.h | 75 ++++++ .../scalar_pointer_lowerer_test.cpp | 107 +++++++++ 8 files changed, 448 insertions(+), 119 deletions(-) create mode 100644 taichi/transforms/scalar_pointer_lowerer.cpp create mode 100644 taichi/transforms/scalar_pointer_lowerer.h create mode 100644 tests/cpp/transforms/scalar_pointer_lowerer_test.cpp diff --git a/taichi/analysis/arithmetic_interpretor.cpp b/taichi/analysis/arithmetic_interpretor.cpp index 95045fcac7648..25557bb515df2 100644 --- a/taichi/analysis/arithmetic_interpretor.cpp +++ b/taichi/analysis/arithmetic_interpretor.cpp @@ -97,7 +97,35 @@ class EvalVisitor : public IRVisitor { } } + void visit(BitExtractStmt *stmt) override { + auto val_opt = context_.maybe_get(stmt->input); + if (!val_opt) { + failed_ = true; + return; + } + const uint64_t mask = (1ULL << (stmt->bit_end - stmt->bit_begin)) - 1; + auto val = val_opt.value().val_int(); + val = (val >> stmt->bit_begin) & mask; + insert_to_ctx(stmt, stmt->ret_type, val); + } + + void visit(LinearizeStmt *stmt) override { + int64_t val = 0; + for (int i = 0; i < (int)stmt->inputs.size(); ++i) { + auto idx_opt = context_.maybe_get(stmt->inputs[i]); + if (!idx_opt) { + failed_ = true; + return; + } + val = (val * stmt->strides[i]) + idx_opt.value().val_int(); + } + insert_to_ctx(stmt, stmt->ret_type, val); + } + void visit(Stmt *stmt) override { + if (context_.should_ignore(stmt)) { + return; + } failed_ = (context_.maybe_get(stmt) == std::nullopt); } @@ -135,6 +163,11 @@ class EvalVisitor : public IRVisitor { context_.insert(stmt, TypedConstant(dt, val_opt.value())); } + template + void insert_to_ctx(const Stmt *stmt, DataType dt, const T &val) { + context_.insert(stmt, TypedConstant(dt, val)); + } + EvalContext context_; bool failed_{false}; }; diff --git a/taichi/analysis/arithmetic_interpretor.h b/taichi/analysis/arithmetic_interpretor.h index 8d44027f9d4b6..b9c6581726ac3 100644 --- a/taichi/analysis/arithmetic_interpretor.h +++ b/taichi/analysis/arithmetic_interpretor.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "taichi/ir/statements.h" @@ -20,11 +21,23 @@ class ArithmeticInterpretor { */ class EvalContext { public: + /** + * Pre-defines a value for statement @param s. + * + * @param s: Statement to be evaluated + * @param c: Predefined value + */ EvalContext &insert(const Stmt *s, TypedConstant c) { map_[s] = c; return *this; } + /** + * Tries to get the evaluated value for statement @param s. + * + * @param s: Statement to get + * @return: The evaluated value, empty if not found. + */ std::optional maybe_get(const Stmt *s) const { auto itr = map_.find(s); if (itr == map_.end()) { @@ -33,8 +46,30 @@ class ArithmeticInterpretor { return itr->second; } + /** + * Tells the interpretor to ignore statement @param s. + * + * This is effective only for statements that are not supported by + * ArithmeticInterpretor. + * + * @param s: Statemet to ignore + */ + void ignore(const Stmt *s) { + ignored_.insert(s); + } + + /** + * Checks if statement @param s is ignored. + * + * @return: True if ignored + */ + bool should_ignore(const Stmt *s) { + return ignored_.count(s) > 0; + } + private: std::unordered_map map_; + std::unordered_set ignored_; }; /** @@ -55,7 +90,7 @@ class ArithmeticInterpretor { * Evaluates the sequence of CHI as defined in |region|. * @param region: A sequence of CHI statements to be evaluated * @param init_ctx: This context can mock the result for certain types of - * statements that are not supported, or cannot be evaluated statically. + * statements that are not supported, or cannot be evaluated statically. */ std::optional evaluate(const CodeRegion ®ion, const EvalContext &init_ctx) const; diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index 5521c92be99ba..36cd0cc1fa0d8 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -23,7 +23,7 @@ void IRBuilder::reset() { insert_point_.position = 0; } -std::unique_ptr IRBuilder::extract_ir() { +std::unique_ptr IRBuilder::extract_ir() { auto result = std::move(root_); reset(); return result; diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 3a6cb7526bfd4..1236cab9c762b 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -19,7 +19,7 @@ class IRBuilder { void reset(); // Extract the IR. - std::unique_ptr extract_ir(); + std::unique_ptr extract_ir(); // General inserter. Returns stmt.get(). template @@ -235,7 +235,7 @@ class IRBuilder { } private: - std::unique_ptr root_{nullptr}; + std::unique_ptr root_{nullptr}; InsertPoint insert_point_; }; diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 7b2e0b3ec57a8..1e907f6dca5ef 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -6,11 +6,37 @@ #include "taichi/program/kernel.h" #include "taichi/program/program.h" #include "taichi/transforms/lower_access.h" +#include "taichi/transforms/scalar_pointer_lowerer.h" #include #include -TLANG_NAMESPACE_BEGIN +namespace taichi { +namespace lang { +namespace { + +class LowerAccess; + +class PtrLowererImpl : public ScalarPointerLowerer { + public: + using ScalarPointerLowerer::ScalarPointerLowerer; + + void set_lower_access(LowerAccess *la); + + void set_pointer_needs_activation(bool v) { + pointer_needs_activation_ = v; + } + + protected: + Stmt *handle_snode_at_level(int level, + LinearizeStmt *linearized, + Stmt *last) override; + + private: + LowerAccess *la_{nullptr}; + std::unordered_set snodes_on_loop_; + bool pointer_needs_activation_{false}; +}; // Lower GlobalPtrStmt into smaller pieces for access optimization @@ -62,120 +88,22 @@ class LowerAccess : public IRVisitor { current_struct_for = nullptr; } - void lower_scalar_ptr(VecStatement &lowered, - SNode *leaf_snode, - std::vector indices, - bool pointer_needs_activation, - SNodeOpType snode_op = SNodeOpType::undefined, - bool is_bit_vectorized = false) { + void lower_scalar_ptr(SNode *leaf_snode, + const std::vector indices, + const bool pointer_needs_activation, + const SNodeOpType snode_op, + const bool is_bit_vectorized, + VecStatement *lowered) { if (snode_op == SNodeOpType::is_active) { // For ti.is_active TI_ASSERT(!pointer_needs_activation); } - // emit a sequence of micro access ops - std::set nodes_on_loop; - if (current_struct_for) { - for (SNode *s = current_struct_for->snode; s != nullptr; s = s->parent) { - nodes_on_loop.insert(s); - } - } - - // start_bits is the index of the starting bit for a coordinate - // for a given SNode. It characterizes the relationship between a parent - // and a child SNode: "parent.start = child.start + child.num_bits". - // - // For example, if there are two 1D snodes a and b, - // where a = ti.root.dense(ti.i, 2) and b = a.dense(ti.i, 8), - // we have a.start = b.start + 3 for the i-th dimension. - // When accessing b[15], then bits [0, 3) of 15 are for accessing b, - // and bit [3, 4) of 15 is for accessing a. - int start_bits[taichi_max_num_indices] = {0}; - std::deque snodes; - for (auto s = leaf_snode; s != nullptr; s = s->parent) { - snodes.push_front(s); - for (int j = 0; j < taichi_max_num_indices; j++) { - start_bits[j] += s->extractors[j].num_bits; - } - } - - Stmt *last = lowered.push_back(); - - int path_inc = int(snode_op != SNodeOpType::undefined); - int length = (int)snodes.size() - 1 + path_inc; - for (int i = 0; i < length; i++) { - auto snode = snodes[i]; - if (is_bit_vectorized && snode->type == SNodeType::bit_array && - i == length - 1 && snodes[i - 1]->type == SNodeType::dense) { - continue; - } - std::vector lowered_indices; - std::vector strides; - // extract bits - for (int k_ = 0; k_ < (int)indices.size(); k_++) { - for (int k = 0; k < taichi_max_num_indices; k++) { - if (snode->physical_index_position[k_] == k) { - start_bits[k] -= snode->extractors[k].num_bits; - int begin = start_bits[k]; - int end = begin + snode->extractors[k].num_bits; - auto extracted = - Stmt::make(indices[k_], begin, end); - lowered_indices.push_back(extracted.get()); - lowered.push_back(std::move(extracted)); - strides.push_back(1 << snode->extractors[k].num_bits); - } - } - } - - bool on_loop_tree = nodes_on_loop.find(snode) != nodes_on_loop.end(); - if (on_loop_tree && - indices.size() == current_struct_for->snode->num_active_indices) { - for (int j = 0; j < (int)indices.size(); j++) { - auto diff = irpass::analysis::value_diff_loop_index( - indices[j], current_struct_for, j); - if (!diff.linear_related()) - on_loop_tree = false; - else if (j == (int)indices.size() - 1) { - if (!(0 <= diff.low && - diff.high <= current_struct_for->vectorize)) { - on_loop_tree = false; - } - } else { - if (!diff.certain() || diff.low != 0) { - on_loop_tree = false; - } - } - } - } - - // linearize - auto linearized = - lowered.push_back(lowered_indices, strides); - - if (snode_op != SNodeOpType::undefined && i == (int)snodes.size() - 1) { - // Create a SNodeOp querying if element i(linearized) of node is active - lowered.push_back(snode_op, snodes[i], last, linearized); - } else { - bool kernel_forces_no_activate_snode = - std::find(kernel_forces_no_activate.begin(), - kernel_forces_no_activate.end(), - snode) != kernel_forces_no_activate.end(); - - auto needs_activation = - snode->need_activation() && pointer_needs_activation && - !kernel_forces_no_activate_snode && !on_loop_tree; - - auto lookup = lowered.push_back( - snode, last, linearized, needs_activation); - int chid = snode->child_id(snodes[i + 1]); - if (is_bit_vectorized && snode->type == SNodeType::dense && - i == length - 2) { - last = lowered.push_back(lookup, chid, true); - } else { - last = lowered.push_back(lookup, chid, false); - } - } - } + PtrLowererImpl lowerer{leaf_snode, indices, snode_op, is_bit_vectorized, + lowered}; + lowerer.set_pointer_needs_activation(pointer_needs_activation); + lowerer.set_lower_access(this); + lowerer.run(); } VecStatement lower_vector_ptr(GlobalPtrStmt *ptr, @@ -191,9 +119,9 @@ class LowerAccess : public IRVisitor { indices.push_back(extractor.get()); lowered.push_back(std::move(extractor)); } - lower_scalar_ptr(lowered, ptr->snodes[i], indices, activate, snode_op, - ptr->is_bit_vectorized); - TI_ASSERT(lowered.size()); + lower_scalar_ptr(ptr->snodes[i], indices, activate, snode_op, + ptr->is_bit_vectorized, &lowered); + TI_ASSERT(lowered.size() > 0); lowered_pointers.push_back(lowered.back().get()); } // create shuffle @@ -290,6 +218,75 @@ class LowerAccess : public IRVisitor { } }; +void PtrLowererImpl::set_lower_access(LowerAccess *la) { + la_ = la; + + snodes_on_loop_.clear(); + if (la_->current_struct_for) { + for (SNode *s = la_->current_struct_for->snode; s != nullptr; + s = s->parent) { + snodes_on_loop_.insert(s); + } + } +} + +Stmt *PtrLowererImpl::handle_snode_at_level(int level, + LinearizeStmt *linearized, + Stmt *last) { + // Check whether |snode| is part of the tree being iterated over by struct for + auto *snode = snodes()[level]; + bool on_loop_tree = (snodes_on_loop_.find(snode) != snodes_on_loop_.end()); + auto *current_struct_for = la_->current_struct_for; + if (on_loop_tree && current_struct_for && + (indices_.size() == current_struct_for->snode->num_active_indices)) { + for (int j = 0; j < (int)indices_.size(); j++) { + auto diff = irpass::analysis::value_diff_loop_index( + indices_[j], current_struct_for, j); + if (!diff.linear_related()) { + on_loop_tree = false; + } else if (j == (int)indices_.size() - 1) { + if (!(0 <= diff.low && diff.high <= current_struct_for->vectorize)) { + on_loop_tree = false; + } + } else { + if (!diff.certain() || diff.low != 0) { + on_loop_tree = false; + } + } + } + } + + if ((snode_op_ != SNodeOpType::undefined) && + (level == (int)snodes().size() - 1)) { + // Create a SNodeOp querying if element i(linearized) of node is active + lowered_->push_back(snode_op_, snode, last, linearized); + } else { + const bool kernel_forces_no_activate_snode = + std::find(la_->kernel_forces_no_activate.begin(), + la_->kernel_forces_no_activate.end(), + snode) != la_->kernel_forces_no_activate.end(); + + const bool needs_activation = + snode->need_activation() && pointer_needs_activation_ && + !kernel_forces_no_activate_snode && !on_loop_tree; + + auto lookup = lowered_->push_back(snode, last, linearized, + needs_activation); + int chid = snode->child_id(snodes()[level + 1]); + if (is_bit_vectorized_ && (snode->type == SNodeType::dense) && + (level == path_length() - 2)) { + last = lowered_->push_back(lookup, chid, + /*is_bit_vectorized=*/true); + } else { + last = lowered_->push_back(lookup, chid, + /*is_bit_vectorized=*/false); + } + } + return last; +} + +} // namespace + const PassID LowerAccessPass::id = "LowerAccessPass"; namespace irpass { @@ -304,5 +301,5 @@ bool lower_access(IRNode *root, } } // namespace irpass - -TLANG_NAMESPACE_END +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/scalar_pointer_lowerer.cpp b/taichi/transforms/scalar_pointer_lowerer.cpp new file mode 100644 index 0000000000000..b52b29e102e51 --- /dev/null +++ b/taichi/transforms/scalar_pointer_lowerer.cpp @@ -0,0 +1,82 @@ +#include +#include + +#include "taichi/inc/constants.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/snode.h" +#include "taichi/ir/statements.h" +#include "taichi/transforms/scalar_pointer_lowerer.h" + +namespace taichi { +namespace lang { + +ScalarPointerLowerer::ScalarPointerLowerer(SNode *leaf_snode, + const std::vector &indices, + const SNodeOpType snode_op, + const bool is_bit_vectorized, + VecStatement *lowered) + : indices_(indices), + snode_op_(snode_op), + is_bit_vectorized_(is_bit_vectorized), + lowered_(lowered) { + for (auto *s = leaf_snode; s != nullptr; s = s->parent) { + snodes_.push_back(s); + } + // From root to leaf + std::reverse(snodes_.begin(), snodes_.end()); + + const int path_inc = (int)(snode_op_ != SNodeOpType::undefined); + path_length_ = (int)snodes_.size() - 1 + path_inc; +} + +void ScalarPointerLowerer::run() { + // |start_bits| is the index of the starting bit for a coordinate + // for a given SNode. It characterizes the relationship between a parent + // and a child SNode: "parent.start = child.start + child.num_bits". + // + // For example, if there are two 1D snodes a and b, + // where a = ti.root.dense(ti.i, 2) and b = a.dense(ti.i, 8), + // we have a.start = b.start + 3 for the i-th dimension. + // When accessing b[15], then bits [0, 3) of 15 are for accessing b, + // and bit [3, 4) of 15 is for accessing a. + std::array start_bits = {0}; + for (const auto *s : snodes_) { + for (int j = 0; j < taichi_max_num_indices; j++) { + start_bits[j] += s->extractors[j].num_bits; + } + } + + Stmt *last = lowered_->push_back(); + for (int i = 0; i < path_length_; i++) { + auto *snode = snodes_[i]; + // TODO: Explain this condition + if (is_bit_vectorized_ && (snode->type == SNodeType::bit_array) && + (i == path_length_ - 1) && (snodes_[i - 1]->type == SNodeType::dense)) { + continue; + } + std::vector lowered_indices; + std::vector strides; + // extract bits + for (int k_ = 0; k_ < (int)indices_.size(); k_++) { + for (int k = 0; k < taichi_max_num_indices; k++) { + if (snode->physical_index_position[k_] == k) { + start_bits[k] -= snode->extractors[k].num_bits; + const int begin = start_bits[k]; + const int end = begin + snode->extractors[k].num_bits; + auto extracted = Stmt::make(indices_[k_], begin, end); + lowered_indices.push_back(extracted.get()); + lowered_->push_back(std::move(extracted)); + strides.push_back(1 << snode->extractors[k].num_bits); + } + } + } + // linearize + auto *linearized = + lowered_->push_back(lowered_indices, strides); + + last = handle_snode_at_level(i, linearized, last); + } +} + +} // namespace lang +} // namespace taichi diff --git a/taichi/transforms/scalar_pointer_lowerer.h b/taichi/transforms/scalar_pointer_lowerer.h new file mode 100644 index 0000000000000..78f67a0d8d6ad --- /dev/null +++ b/taichi/transforms/scalar_pointer_lowerer.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +#include "taichi/ir/stmt_op_types.h" + +namespace taichi { +namespace lang { + +class LinearizeStmt; +class SNode; +class Stmt; +class StructForStmt; +class VecStatement; + +/** + * Lowers an SNode at a given indices to a series of concrete ops. + */ +class ScalarPointerLowerer { + public: + /** + * Constructor + * + * @param leaf_snode: SNode of the accessed field + * @param indices: Indices to access the field + * @param snode_op: SNode operation + * @param is_bit_vectorized: Is @param leaf_snode bit vectorized + * @param lowered: Collects the output ops + */ + explicit ScalarPointerLowerer(SNode *leaf_snode, + const std::vector &indices, + const SNodeOpType snode_op, + const bool is_bit_vectorized, + VecStatement *lowered); + + virtual ~ScalarPointerLowerer() = default; + /** + * Runs the lowering process. + * + * This can only be called once. + */ + void run(); + + protected: + /** + * @param level: Level + * @param linearized: a + * @param last: + */ + virtual Stmt *handle_snode_at_level(int level, + LinearizeStmt *linearized, + Stmt *last) { + return last; + } + + std::vector snodes() const { + return snodes_; + } + + int path_length() const { + return path_length_; + } + + const std::vector indices_; + const SNodeOpType snode_op_; + const bool is_bit_vectorized_; + VecStatement *const lowered_; + + private: + std::vector snodes_; + int path_length_{0}; +}; + +} // namespace lang +} // namespace taichi diff --git a/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp new file mode 100644 index 0000000000000..7de569a3cfaa2 --- /dev/null +++ b/tests/cpp/transforms/scalar_pointer_lowerer_test.cpp @@ -0,0 +1,107 @@ +#include + +#include "gtest/gtest.h" +#include "taichi/analysis/arithmetic_interpretor.h" +#include "taichi/ir/ir.h" +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/snode.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/transforms/scalar_pointer_lowerer.h" +#include "tests/cpp/struct/fake_struct_compiler.h" + +namespace taichi { +namespace lang { +namespace { + +constexpr int kPointerSize = 4; +constexpr int kDenseSize = 8; + +class LowererImpl : public ScalarPointerLowerer { + public: + using ScalarPointerLowerer::ScalarPointerLowerer; + std::vector linears; + + protected: + Stmt *handle_snode_at_level(int level, + LinearizeStmt *linearized, + Stmt *last) override { + linears.push_back(linearized); + return last; + } +}; + +class ScalarPointerLowererTest : public ::testing::Test { + protected: + void SetUp() override { + root_snode_ = std::make_unique(/*depth=*/0, /*t=*/SNodeType::root); + const std::vector indices = {Index{0}}; + ptr_snode_ = &(root_snode_->pointer(indices, kPointerSize)); + dense_snode_ = &(ptr_snode_->dense(indices, kDenseSize)); + // Must end with a `place` SNode. + leaf_snode_ = &(dense_snode_->insert_children(SNodeType::place)); + leaf_snode_->dt = PrimitiveType::f32; + + FakeStructCompiler sc; + sc.run(*root_snode_); + } + + const CompileConfig cfg_; + std::unique_ptr root_snode_{nullptr}; + SNode *ptr_snode_{nullptr}; + SNode *dense_snode_{nullptr}; + SNode *leaf_snode_{nullptr}; +}; + +TEST_F(ScalarPointerLowererTest, Basic) { + IRBuilder builder; + for (int i = 0; i < kPointerSize; ++i) { + for (int j = 0; j < kDenseSize; ++j) { + const int loop_index = (i * kDenseSize) + j; + VecStatement lowered; + LowererImpl lowerer{leaf_snode_, + std::vector{builder.get_int32(loop_index)}, + SNodeOpType::undefined, + /*is_bit_vectorized=*/false, &lowered}; + lowerer.run(); + // There are three linearized stmts: + // 0: for root + // 1: for pointer + // 2: for dense + constexpr int kPointerLevel = 1; + constexpr int kDenseLevel = 2; + ASSERT_EQ(lowerer.linears.size(), 3); + + auto block = builder.extract_ir(); + block->insert(std::move(lowered)); + // Set types so that ArithmeticInterpretor can run correctly + irpass::type_check(block.get(), cfg_); + + ArithmeticInterpretor::CodeRegion code_region; + code_region.block = block.get(); + + ArithmeticInterpretor::EvalContext init_ctx; + for (auto &stmt : code_region.block->statements) { + if (stmt->is()) { + init_ctx.ignore(stmt.get()); + break; + } + } + + ArithmeticInterpretor ai; + code_region.end = lowerer.linears[kPointerLevel]; + auto res_opt = ai.evaluate(code_region, init_ctx); + ASSERT_TRUE(res_opt.has_value()); + EXPECT_EQ(res_opt.value(), i); + + code_region.end = lowerer.linears[kDenseLevel]; + res_opt = ai.evaluate(code_region, init_ctx); + ASSERT_TRUE(res_opt.has_value()); + EXPECT_EQ(res_opt.value(), j); + } + } +} + +} // namespace +} // namespace lang +} // namespace taichi From 5fc10b7f0107ce48d5614d40be31ab26ccfa6eb8 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Sun, 6 Jun 2021 23:36:38 +0800 Subject: [PATCH 2/3] comments --- taichi/transforms/lower_access.cpp | 1 + taichi/transforms/scalar_pointer_lowerer.h | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index 1e907f6dca5ef..c002a38a87d14 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -256,6 +256,7 @@ Stmt *PtrLowererImpl::handle_snode_at_level(int level, } } + // Generates the SNode access operations at the current |level|. if ((snode_op_ != SNodeOpType::undefined) && (level == (int)snodes().size() - 1)) { // Create a SNodeOp querying if element i(linearized) of node is active diff --git a/taichi/transforms/scalar_pointer_lowerer.h b/taichi/transforms/scalar_pointer_lowerer.h index 78f67a0d8d6ad..574e9c9e6eceb 100644 --- a/taichi/transforms/scalar_pointer_lowerer.h +++ b/taichi/transforms/scalar_pointer_lowerer.h @@ -43,9 +43,11 @@ class ScalarPointerLowerer { protected: /** - * @param level: Level - * @param linearized: a - * @param last: + * Handles the SNode at a given @param level. + * + * @param level: Level of the SNode in the access path + * @param linearized: Linearized indices statement for this level + * @param last: SNode access op (e.g. GetCh) of the last iteration */ virtual Stmt *handle_snode_at_level(int level, LinearizeStmt *linearized, From b9097d01bfd60312c4e6861ab769e1f2ec3e044e Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Tue, 8 Jun 2021 13:51:04 +0800 Subject: [PATCH 3/3] Update taichi/analysis/arithmetic_interpretor.h Co-authored-by: xumingkuan --- taichi/analysis/arithmetic_interpretor.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/analysis/arithmetic_interpretor.h b/taichi/analysis/arithmetic_interpretor.h index b9c6581726ac3..6c56a8ed3db0c 100644 --- a/taichi/analysis/arithmetic_interpretor.h +++ b/taichi/analysis/arithmetic_interpretor.h @@ -52,7 +52,7 @@ class ArithmeticInterpretor { * This is effective only for statements that are not supported by * ArithmeticInterpretor. * - * @param s: Statemet to ignore + * @param s: Statement to ignore */ void ignore(const Stmt *s) { ignored_.insert(s);