From 9feea6bb1d67b83870863edc5ce22a3761f42103 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 16:18:17 -0400 Subject: [PATCH 01/18] Let DataType pretend to be LegacyVectorType --- taichi/ir/ir.cpp | 4 ++++ taichi/ir/ir.h | 44 +++++------------------------------------- taichi/ir/statements.h | 6 +++--- taichi/ir/type.cpp | 8 +++++++- taichi/ir/type.h | 24 +++++++++++++++++++++-- 5 files changed, 41 insertions(+), 45 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 1f864055d3f2e..b886e5ce064d9 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -183,6 +183,10 @@ Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { ret_type = stmt.ret_type; } +Stmt::Stmt(const DataType &data_type): Stmt() { + +} + Stmt *Stmt::insert_before_me(std::unique_ptr &&new_stmt) { auto ret = new_stmt.get(); TI_ASSERT(parent); diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index be0e8b8b32e5b..193701e8f1643 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -34,43 +34,6 @@ using ScratchPadOptions = std::vector>; IRBuilder ¤t_ast_builder(); -struct LegacyVectorType { - private: - bool _is_pointer; - - public: - int width; - DataType data_type; - - LegacyVectorType(int width, DataType data_type, bool is_pointer = false) - : _is_pointer(is_pointer), width(width), data_type(data_type) { - } - - LegacyVectorType() - : _is_pointer(false), width(1), data_type(PrimitiveType::unknown) { - } - - bool operator==(const LegacyVectorType &o) const { - return width == o.width && data_type == o.data_type; - } - - bool operator!=(const LegacyVectorType &o) const { - return !(*this == o); - } - - std::string pointer_suffix() const; - std::string element_type_name() const; - std::string str() const; - - bool is_pointer() const { - return _is_pointer; - } - - void set_is_pointer(bool v) { - _is_pointer = v; - } -}; - class DecoratorRecorder { public: int vectorize; @@ -531,11 +494,14 @@ class Stmt : public IRNode { bool erased; bool fields_registered; std::string tb; - LegacyVectorType ret_type; + DataType ret_type; Stmt(); Stmt(const Stmt &stmt); + // TODO: remove this after type refactoring is done + Stmt(const DataType &data_type); + int &width() { return ret_type.width; } @@ -553,7 +519,7 @@ class Stmt : public IRNode { } std::string ret_data_type_name() const { - return ret_type.str(); + return ret_type->to_string(); } std::string type_hint() const; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 5daa994525269..ae251bc070713 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -934,7 +934,7 @@ class GlobalTemporaryStmt : public Stmt { public: std::size_t offset; - GlobalTemporaryStmt(std::size_t offset, LegacyVectorType ret_type) + GlobalTemporaryStmt(std::size_t offset, DataType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; @@ -952,7 +952,7 @@ class ThreadLocalPtrStmt : public Stmt { public: std::size_t offset; - ThreadLocalPtrStmt(std::size_t offset, LegacyVectorType ret_type) + ThreadLocalPtrStmt(std::size_t offset, DataType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; @@ -970,7 +970,7 @@ class BlockLocalPtrStmt : public Stmt { public: Stmt *offset; - BlockLocalPtrStmt(Stmt *offset, LegacyVectorType ret_type) : offset(offset) { + BlockLocalPtrStmt(Stmt *offset, DataType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; } diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 545a9018fcb82..170fb884d3936 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -16,7 +16,7 @@ TLANG_NAMESPACE_BEGIN #include "taichi/inc/data_type.inc.h" #undef PER_TYPE -DataType::DataType() : ptr_(PrimitiveType::unknown.ptr_) { +DataType::DataType() : data_type(*this) ,ptr_(PrimitiveType::unknown.ptr_) { } DataType PrimitiveType::get(PrimitiveType::primitive_type t) { @@ -42,4 +42,10 @@ std::string PrimitiveType::to_string() const { return data_type_name(DataType(this)); } +DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) { + TI_ASSERT(!is_pointer); + TI_ASSERT(width == 1); + return data_type; +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 31cc19ed963ff..29fefae0fc050 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -16,7 +16,10 @@ class DataType { public: DataType(); - DataType(const Type *ptr) : ptr_(ptr) { + DataType(Type *ptr) : data_type(*this), ptr_(ptr) { + } + + DataType(const DataType &o) : data_type(*this), ptr_(o.ptr_) { } bool operator==(const DataType &o) const { @@ -38,8 +41,21 @@ class DataType { return ptr_; } + // To be compatible with LegacyVectorType + int width{1}; + DataType &data_type; + + Type *operator->() const { + return ptr_; + } + + DataType &operator=(const DataType &o) { + ptr_ = o.ptr_; + return *this; + } + private: - const Type *ptr_; + Type *ptr_; }; class PrimitiveType : public Type { @@ -115,4 +131,8 @@ class VectorType : public Type { Type *element_{nullptr}; }; +DataType LegacyVectorType(int width, + DataType data_type, + bool is_pointer = false); + TLANG_NAMESPACE_END From 6373a211eca0a50d4517c08ab888e0bbacd5723b Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 16:20:44 -0400 Subject: [PATCH 02/18] Type RTTI helpers --- taichi/codegen/codegen_llvm.cpp | 1 + taichi/ir/type.cpp | 4 ++++ taichi/ir/type.h | 25 +++++++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index c8666941481aa..b761eea3c91c7 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -132,6 +132,7 @@ void CodeGenLLVM::visit(AllocaStmt *stmt) { TI_ASSERT(stmt->width() == 1); llvm_val[stmt] = create_entry_block_alloca(stmt->ret_type.data_type, stmt->ret_type.is_pointer()); + // TODO: upgrade to new type system // initialize as zero if element is not a pointer if (!stmt->ret_type.is_pointer()) builder->CreateStore(tlctx->get_constant(stmt->ret_type.data_type, 0), diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 170fb884d3936..198f6706f57bb 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -38,6 +38,10 @@ std::size_t DataType::hash() const { } } +bool DataType::is_pointer() const { + return ptr_->is(); +} + std::string PrimitiveType::to_string() const { return data_type_name(DataType(this)); } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 29fefae0fc050..d62bbf299f489 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -7,6 +7,29 @@ TLANG_NAMESPACE_BEGIN class Type { public: virtual std::string to_string() const = 0; + + template + bool is() const { + return cast() != nullptr; + } + + template + const T *cast() const { + return dynamic_cast(this); + } + + template + T *cast() { + return dynamic_cast(this); + } + + template + T *as() { + auto p = dynamic_cast(this); + TI_ASSERT(p != nullptr); + return p; + } + virtual ~Type() { } }; @@ -54,6 +77,8 @@ class DataType { return *this; } + bool is_pointer() const; + private: Type *ptr_; }; From 9dfb25cf284a471d56fc9ec6f5af735f52c89c29 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 17:09:34 -0400 Subject: [PATCH 03/18] more fixes --- taichi/ir/ir.cpp | 19 +------------------ taichi/ir/type.cpp | 18 +++++++++++++++--- taichi/ir/type.h | 4 +++- taichi/lang_util.cpp | 25 +++++++++++++++++-------- taichi/lang_util.h | 29 +++++++++++++++++++++++++++++ taichi/transforms/auto_diff.cpp | 8 ++++++-- taichi/transforms/offload.cpp | 4 ++-- taichi/transforms/type_check.cpp | 2 +- 8 files changed, 74 insertions(+), 35 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index b886e5ce064d9..c3f4cc0fea517 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -21,23 +21,6 @@ IRBuilder ¤t_ast_builder() { return context->builder(); } -std::string LegacyVectorType::pointer_suffix() const { - if (is_pointer()) { - return "*"; - } else { - return ""; - } -} - -std::string LegacyVectorType::element_type_name() const { - return fmt::format("{}{}", data_type_short_name(data_type), pointer_suffix()); -} - -std::string LegacyVectorType::str() const { - auto ename = element_type_name(); - return fmt::format("{:4}x{}", ename, width); -} - void DecoratorRecorder::reset() { vectorize = -1; parallelize = 0; @@ -242,7 +225,7 @@ std::string Stmt::type_hint() const { if (ret_type.data_type == PrimitiveType::unknown) return ""; else - return fmt::format("<{}>", ret_type.str()); + return fmt::format("<{}>", ret_type.to_string()); } std::string Stmt::type() { diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index 198f6706f57bb..cec388e7fc3a8 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -42,14 +42,26 @@ bool DataType::is_pointer() const { return ptr_->is(); } +void DataType::set_is_pointer(bool is_ptr) { + if (is_ptr && !ptr_->is()) { + ptr_ = Program::get_type_factory().get_pointer_type(ptr_); + } + if (!is_ptr && ptr_->is()) { + ptr_ = ptr_->cast()->get_pointee_type(); + } +} + std::string PrimitiveType::to_string() const { - return data_type_name(DataType(this)); + return data_type_name(DataType(const_cast(this))); } DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) { - TI_ASSERT(!is_pointer); + if (is_pointer) { + return Program::get_type_factory().get_pointer_type(data_type.get_ptr()); + } else { + return data_type; + } TI_ASSERT(width == 1); - return data_type; } TLANG_NAMESPACE_END diff --git a/taichi/ir/type.h b/taichi/ir/type.h index d62bbf299f489..8000ddba26331 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -60,7 +60,7 @@ class DataType { }; // TODO: DataType itself should be a pointer in the future - const Type *get_ptr() const { + Type *get_ptr() const { return ptr_; } @@ -79,6 +79,8 @@ class DataType { bool is_pointer() const; + void set_is_pointer(bool ptr); + private: Type *ptr_; }; diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index d6ea762c8164b..ec9dd90c9afd1 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -321,10 +321,11 @@ namespace { class TypePromotionMapping { public: TypePromotionMapping() { -#define TRY_SECOND(x, y) \ - mapping[std::make_pair(to_primitive_type(get_data_type()), \ - to_primitive_type(get_data_type()))] = \ - get_data_type() + std::declval())>(); +#define TRY_SECOND(x, y) \ + mapping[std::make_pair(get_data_primitive_type(), \ + get_data_primitive_type())] = \ + get_data_primitive_type() + \ + std::declval())>(); #define TRY_FIRST(x) \ TRY_SECOND(x, float32); \ TRY_SECOND(x, float64); \ @@ -349,16 +350,24 @@ class TypePromotionMapping { TRY_FIRST(uint64); } DataType query(DataType x, DataType y) { - return mapping[std::make_pair(to_primitive_type(x), to_primitive_type(y))]; + auto primitive = + mapping[std::make_pair(to_primitive_type(x), to_primitive_type(y))]; + return Program::get_type_factory().get_primitive_type(primitive); } private: std::map< std::pair, - DataType> + PrimitiveType::primitive_type> mapping; - static PrimitiveType::primitive_type to_primitive_type(const DataType d) { - auto primitive = dynamic_cast(d.get_ptr()); + static PrimitiveType::primitive_type to_primitive_type(const DataType d_) { + Type *d = d_.get_ptr(); + if (d->is()) { + d = d->as()->get_pointee_type(); + TI_WARN("promoted_type got a pointer input."); + } + + auto primitive = d->cast(); TI_ASSERT_INFO( primitive, "Failed to get primitive type! " diff --git a/taichi/lang_util.h b/taichi/lang_util.h index 106d389cf585c..a63fdb7e2aac9 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -49,6 +49,35 @@ inline DataType get_data_type() { } } +template +inline PrimitiveType::primitive_type get_data_primitive_type() { + if (std::is_same()) { + return PrimitiveType::primitive_type::f32; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::f64; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::u1; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::i8; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::i16; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::i32; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::i64; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::u8; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::u16; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::u32; + } else if (std::is_same()) { + return PrimitiveType::primitive_type::u64; + } else { + TI_NOT_IMPLEMENTED; + } +} + std::string data_type_name(DataType t); std::string data_type_format(DataType dt); diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index feff8e8ea89eb..b954e80121901 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -739,11 +739,15 @@ class BackupSSA : public BasicStmtVisitor { if (backup_alloca.find(stmt) == backup_alloca.end()) { auto alloca = Stmt::make(stmt->width(), stmt->ret_type.data_type); - alloca->ret_type.set_is_pointer(stmt->ret_type.is_pointer()); + // TODO: the line below was deleted during type system refactoring. + // Hopefully it's no longer needed. + // alloca->ret_type.set_is_pointer(stmt->ret_type.is_pointer()); auto alloca_ptr = alloca.get(); independent_block->insert(std::move(alloca), 0); auto local_store = Stmt::make(alloca_ptr, stmt); - local_store->ret_type.set_is_pointer(stmt->ret_type.is_pointer()); + // TODO: the line below was deleted during type system refactoring. + // Hopefully it's no longer needed. + // local_store->ret_type.set_is_pointer(stmt->ret_type.is_pointer()); stmt->insert_after_me(std::move(local_store)); backup_alloca[stmt] = alloca_ptr; } diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index b3f761bd49764..6be056bbcf6b6 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -258,7 +258,7 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor { global_offset = 0; } - std::size_t allocate_global(LegacyVectorType type) { + std::size_t allocate_global(DataType type) { TI_ASSERT(type.width == 1); auto ret = global_offset; global_offset += data_type_size(type.data_type); @@ -564,7 +564,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { StmtToOffsetMap local_to_global_offset; std::unordered_map stmt_to_offloaded; OffloadedRanges *const offloaded_ranges_; - std::unordered_map local_to_global_vector_type; + std::unordered_map local_to_global_vector_type; }; void insert_gc(IRNode *root) { diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 37b58de7d20e1..b2b77c36a02f3 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -25,7 +25,7 @@ class TypeCheck : public IRVisitor { allow_undefined_visitor = true; } - static void mark_as_if_const(Stmt *stmt, LegacyVectorType t) { + static void mark_as_if_const(Stmt *stmt, DataType t) { if (stmt->is()) { stmt->ret_type = t; } From cee5a7919abee4b0f53df01435d229cb3832e330 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 17:28:26 -0400 Subject: [PATCH 04/18] laplace.py runs --- examples/laplace.py | 2 +- python/taichi/lang/util.py | 20 ++++++++++---------- taichi/ir/ir.cpp | 2 +- taichi/lang_util.cpp | 27 +++++++-------------------- taichi/transforms/type_check.cpp | 4 +++- 5 files changed, 22 insertions(+), 33 deletions(-) diff --git a/examples/laplace.py b/examples/laplace.py index 07f75e893b80f..9d09d33c94183 100644 --- a/examples/laplace.py +++ b/examples/laplace.py @@ -1,6 +1,6 @@ import taichi as ti -ti.init(arch=ti.cpu) +ti.init(arch=ti.cpu, print_ir=True, print_accessor_ir=True) x, y = ti.field(ti.f32), ti.field(ti.f32) ti.root.dense(ti.ij, 16).place(x, y) diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index dd0cac67b6dfd..6bd9316a8977f 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -30,9 +30,9 @@ def is_taichi_class(rhs): # Real types -float32 = taichi_lang_core.DataType_float32 +float32 = taichi_lang_core.DataType_f32 f32 = float32 -float64 = taichi_lang_core.DataType_float64 +float64 = taichi_lang_core.DataType_f64 f64 = float64 real_types = [f32, f64, float] @@ -40,22 +40,22 @@ def is_taichi_class(rhs): # Integer types -int8 = taichi_lang_core.DataType_int8 +int8 = taichi_lang_core.DataType_i8 i8 = int8 -int16 = taichi_lang_core.DataType_int16 +int16 = taichi_lang_core.DataType_i16 i16 = int16 -int32 = taichi_lang_core.DataType_int32 +int32 = taichi_lang_core.DataType_i32 i32 = int32 -int64 = taichi_lang_core.DataType_int64 +int64 = taichi_lang_core.DataType_i64 i64 = int64 -uint8 = taichi_lang_core.DataType_uint8 +uint8 = taichi_lang_core.DataType_u8 u8 = uint8 -uint16 = taichi_lang_core.DataType_uint16 +uint16 = taichi_lang_core.DataType_u16 u16 = uint16 -uint32 = taichi_lang_core.DataType_uint32 +uint32 = taichi_lang_core.DataType_u32 u32 = uint32 -uint64 = taichi_lang_core.DataType_uint64 +uint64 = taichi_lang_core.DataType_u64 u64 = uint64 integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int] diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index c3f4cc0fea517..80377c248a2e2 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -225,7 +225,7 @@ std::string Stmt::type_hint() const { if (ret_type.data_type == PrimitiveType::unknown) return ""; else - return fmt::format("<{}>", ret_type.to_string()); + return fmt::format("|{}| ", ret_type.to_string()); } std::string Stmt::type() { diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index ec9dd90c9afd1..aa02c0da477c5 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -66,26 +66,7 @@ real measure_cpe(std::function target, } std::string data_type_name(DataType t) { -#define REGISTER_DATA_TYPE(i, j) else if (t == PrimitiveType::i) return #j - if (false) { - } - REGISTER_DATA_TYPE(f16, float16); - REGISTER_DATA_TYPE(f32, float32); - REGISTER_DATA_TYPE(f64, float64); - REGISTER_DATA_TYPE(u1, int1); - REGISTER_DATA_TYPE(i8, int8); - REGISTER_DATA_TYPE(i16, int16); - REGISTER_DATA_TYPE(i32, int32); - REGISTER_DATA_TYPE(i64, int64); - REGISTER_DATA_TYPE(u8, uint8); - REGISTER_DATA_TYPE(u16, uint16); - REGISTER_DATA_TYPE(u32, uint32); - REGISTER_DATA_TYPE(u64, uint64); - REGISTER_DATA_TYPE(gen, generic); - REGISTER_DATA_TYPE(unknown, unknown); - -#undef REGISTER_DATA_TYPE - else TI_NOT_IMPLEMENTED + return data_type_short_name(t); } std::string data_type_format(DataType dt) { @@ -136,6 +117,12 @@ int data_type_size(DataType t) { } std::string data_type_short_name(DataType t) { + if (!t->is()) { + return t->to_string(); + } + + // Handle primitive types below. + if (false) { } #define PER_TYPE(i) else if (t == PrimitiveType::i) return #i; diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index b2b77c36a02f3..78dfbba73e803 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -374,7 +374,9 @@ class TypeCheck : public IRVisitor { void visit(GetChStmt *stmt) { stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt); - stmt->ret_type.set_is_pointer(true); + if (stmt->output_snode->type != SNodeType::place) { + stmt->ret_type.set_is_pointer(true); + } // for place SNodes GetCh directly yields a numerical value. } void visit(OffloadedStmt *stmt) { From 741dcb51ec10fc170b24dfc4521269974a5204fc Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 20:42:12 -0400 Subject: [PATCH 05/18] fixing tmp.py --- taichi/ir/type.cpp | 16 ++++++++++++++-- taichi/ir/type.h | 3 +++ taichi/ir/type_factory.cpp | 1 - taichi/lang_util.cpp | 7 +++++++ taichi/lang_util.h | 2 ++ taichi/transforms/type_check.cpp | 24 +++++++++++++----------- tmp.py | 16 ++++++++++++++++ 7 files changed, 55 insertions(+), 14 deletions(-) create mode 100644 tmp.py diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index cec388e7fc3a8..b123732274a48 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -16,7 +16,7 @@ TLANG_NAMESPACE_BEGIN #include "taichi/inc/data_type.inc.h" #undef PER_TYPE -DataType::DataType() : data_type(*this) ,ptr_(PrimitiveType::unknown.ptr_) { +DataType::DataType() : data_type(*this), ptr_(PrimitiveType::unknown.ptr_) { } DataType PrimitiveType::get(PrimitiveType::primitive_type t) { @@ -31,8 +31,10 @@ DataType PrimitiveType::get(PrimitiveType::primitive_type t) { } std::size_t DataType::hash() const { - if (auto primitive = dynamic_cast(ptr_)) { + if (auto primitive = ptr_->cast()) { return (std::size_t)primitive->type; + } else if (auto pointer = ptr_->cast()) { + return 10007 + DataType(pointer->get_pointee_type()).hash(); } else { TI_NOT_IMPLEMENTED } @@ -51,6 +53,16 @@ void DataType::set_is_pointer(bool is_ptr) { } } +DataType DataType::ptr_removed() const { + auto t = ptr_; + auto ptr_type = t->cast(); + if (ptr_type) { + return DataType(ptr_type->get_pointee_type()); + } else { + return *this; + } +} + std::string PrimitiveType::to_string() const { return data_type_name(DataType(const_cast(this))); } diff --git a/taichi/ir/type.h b/taichi/ir/type.h index 8000ddba26331..cbb3353d7f7b8 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -81,6 +81,9 @@ class DataType { void set_is_pointer(bool ptr); + // Temporary API + DataType ptr_removed() const; + private: Type *ptr_; }; diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 645b8f576bcb6..aa1d8250a7a40 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -1,5 +1,4 @@ #include "taichi/ir/type_factory.h" -#include "type_factory.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index aa02c0da477c5..09e41f6fdd778 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -88,6 +88,11 @@ std::string data_type_format(DataType dt) { } int data_type_size(DataType t) { + // TODO: + // 1. ensure in the old code, pointer attributes of t are correct (by setting + // a loud failure on pointers); + // 2. support pointer here. + t.set_is_pointer(false); if (false) { } else if (t == PrimitiveType::f16) return 2; @@ -370,6 +375,8 @@ DataType promoted_type(DataType a, DataType b) { } std::string TypedConstant::stringify() const { + // TODO: remove the line below after type system upgrade. + auto dt = this->dt.ptr_removed(); if (dt == PrimitiveType::f32) { return fmt::format("{}", val_f32); } else if (dt == PrimitiveType::i32) { diff --git a/taichi/lang_util.h b/taichi/lang_util.h index a63fdb7e2aac9..eab69c3956a36 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -243,6 +243,8 @@ class TypedConstant { template TypedConstant(DataType dt, const T &value) : dt(dt) { + // TODO: loud failure on pointers + dt.set_is_pointer(false); if (dt == PrimitiveType::f32) { val_f32 = value; } else if (dt == PrimitiveType::i32) { diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 78dfbba73e803..3ae1cbcb037af 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -63,17 +63,18 @@ class TypeCheck : public IRVisitor { void visit(AtomicOpStmt *stmt) { TI_ASSERT(stmt->width() == 1); - if (stmt->val->ret_type.data_type != stmt->dest->ret_type.data_type) { + if (stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) { + // TODO: make sure the ptr_removed type is indeed a numerical type TI_WARN("[{}] Atomic add ({} to {}) may lose precision.", stmt->name(), data_type_name(stmt->val->ret_type.data_type), - data_type_name(stmt->dest->ret_type.data_type)); + data_type_name(stmt->dest->ret_type.ptr_removed())); stmt->val = insert_type_cast_before(stmt, stmt->val, - stmt->dest->ret_type.data_type); + stmt->dest->ret_type.ptr_removed()); } if (stmt->element_type() == PrimitiveType::unknown) { - stmt->ret_type = stmt->dest->ret_type; + stmt->ret_type = stmt->dest->ret_type.ptr_removed(); } - stmt->ret_type.set_is_pointer(false); + TI_ASSERT(!stmt->ret_type->is()); } void visit(LocalLoadStmt *stmt) { @@ -146,19 +147,20 @@ class TypeCheck : public IRVisitor { } void visit(GlobalStoreStmt *stmt) { - auto promoted = promoted_type(stmt->ptr->ret_type.data_type, + auto promoted = promoted_type(stmt->ptr->ret_type.ptr_removed(), stmt->data->ret_type.data_type); auto input_type = stmt->data->ret_data_type_name(); - if (stmt->ptr->ret_type.data_type != stmt->data->ret_type.data_type) { + if (stmt->ptr->ret_type.data_type.ptr_removed() != + stmt->data->ret_type.data_type) { stmt->data = insert_type_cast_before(stmt, stmt->data, - stmt->ptr->ret_type.data_type); + stmt->ptr->ret_type.ptr_removed()); } - if (stmt->ptr->ret_type.data_type != promoted) { + if (stmt->ptr->ret_type.ptr_removed() != promoted) { TI_WARN("[{}] Global store may lose precision: {} <- {}, at", stmt->name(), stmt->ptr->ret_data_type_name(), input_type); TI_WARN("\n{}", stmt->tb); } - stmt->ret_type = stmt->ptr->ret_type; + stmt->ret_type = stmt->ptr->ret_type.ptr_removed(); } void visit(RangeForStmt *stmt) { @@ -376,7 +378,7 @@ class TypeCheck : public IRVisitor { stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt); if (stmt->output_snode->type != SNodeType::place) { stmt->ret_type.set_is_pointer(true); - } // for place SNodes GetCh directly yields a numerical value. + } // for place SNodes GetCh directly yields a numerical value. } void visit(OffloadedStmt *stmt) { diff --git a/tmp.py b/tmp.py new file mode 100644 index 0000000000000..dffc67b6510f1 --- /dev/null +++ b/tmp.py @@ -0,0 +1,16 @@ +import taichi as ti + +ti.init(print_ir=True) + + +A = ti.field(ti.f32, shape=()) + +@ti.kernel +def func(): + a = 0 + for i in range(10): + a -= i + A[None] = a + +func() +assert A[None] == -45 From 751ac48ae99e417669ffac56729dd10545b7c70a Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 20:51:15 -0400 Subject: [PATCH 06/18] fix local atomic tests on LLVM --- taichi/codegen/codegen_llvm.cpp | 13 +++++++------ taichi/transforms/type_check.cpp | 1 + tmp.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index b761eea3c91c7..388a633b45e0c 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -868,6 +868,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) { llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0); llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty); } else { + TI_ASSERT(!stmt->ret_type.data_type->is()); dest_ty = tlctx->get_data_type(stmt->ret_type.data_type); auto dest_bits = dest_ty->getPrimitiveSizeInBits(); auto truncated = builder->CreateTrunc( @@ -1620,8 +1621,8 @@ void CodeGenLLVM::visit(GlobalTemporaryStmt *stmt) { tlctx->get_constant((int64)stmt->offset)); TI_ASSERT(stmt->width() == 1); - auto ptr_type = - llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0); + auto ptr_type = llvm::PointerType::get( + tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type); } @@ -1629,8 +1630,8 @@ void CodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) { auto base = get_tls_base_ptr(); TI_ASSERT(stmt->width() == 1); auto ptr = builder->CreateGEP(base, tlctx->get_constant(stmt->offset)); - auto ptr_type = - llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0); + auto ptr_type = llvm::PointerType::get( + tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } @@ -1640,8 +1641,8 @@ void CodeGenLLVM::visit(BlockLocalPtrStmt *stmt) { TI_ASSERT(stmt->width() == 1); auto ptr = builder->CreateGEP( base, {tlctx->get_constant(0), llvm_val[stmt->offset]}); - auto ptr_type = - llvm::PointerType::get(tlctx->get_data_type(stmt->ret_type.data_type), 0); + auto ptr_type = llvm::PointerType::get( + tlctx->get_data_type(stmt->ret_type.data_type.ptr_removed()), 0); llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type); } diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 3ae1cbcb037af..8fd7c6c0eaf25 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -335,6 +335,7 @@ class TypeCheck : public IRVisitor { // verification, without modifying any types. TI_ASSERT(rt.data_type != PrimitiveType::unknown); TI_ASSERT(rt.width == 1); + stmt->ret_type.set_is_pointer(stmt->is_ptr); } void visit(KernelReturnStmt *stmt) { diff --git a/tmp.py b/tmp.py index dffc67b6510f1..d3da9b5fe71fc 100644 --- a/tmp.py +++ b/tmp.py @@ -1,6 +1,6 @@ import taichi as ti -ti.init(print_ir=True) +ti.init(print_ir=True, print_accessor_ir=True) A = ti.field(ti.f32, shape=()) From 64898402f0ff49f20bd0ab447005e5b332a0abf4 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 21:43:00 -0400 Subject: [PATCH 07/18] fix opengl and cc --- taichi/backends/opengl/codegen_opengl.cpp | 2 +- tests/python/test_atomic.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 7cd6f85f50b72..dd2d78eea1466 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -573,7 +573,7 @@ class KernelGen : public IRVisitor { void visit(AtomicOpStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - auto dt = stmt->dest->element_type(); + auto dt = stmt->dest->element_type().ptr_removed(); if (dt == PrimitiveType::i32 || (TI_OPENGL_REQUIRE(used, GL_NV_shader_atomic_int64) && dt == PrimitiveType::i64) || diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py index 4590fc7ad0563..49af5fab30856 100644 --- a/tests/python/test_atomic.py +++ b/tests/python/test_atomic.py @@ -35,10 +35,6 @@ def func(): assert valproc(ya) == e -ti.init(ti.cc, log_level=ti.DEBUG) -run_atomic_add_global_case(ti.i32, 42) - - @ti.all_archs def test_atomic_add_global_i32(): run_atomic_add_global_case(ti.i32, 42) From 3dc5a40c6f740f1e748b61d2750a66830e84a2f6 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 21:47:20 -0400 Subject: [PATCH 08/18] fix llvm external pointer --- taichi/codegen/codegen_llvm.cpp | 2 +- tmp.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 388a633b45e0c..ba52551f3d3f4 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1229,7 +1229,7 @@ void CodeGenLLVM::visit(ExternalPtrStmt *stmt) { sizes[i] = raw_arg; } - auto dt = stmt->ret_type.data_type; + auto dt = stmt->ret_type.data_type.ptr_removed(); auto base = builder->CreateBitCast( llvm_val[stmt->base_ptrs[0]], llvm::PointerType::get(tlctx->get_data_type(dt), 0)); diff --git a/tmp.py b/tmp.py index d3da9b5fe71fc..fb19c680f6de9 100644 --- a/tmp.py +++ b/tmp.py @@ -1,16 +1,19 @@ import taichi as ti +import numpy as np ti.init(print_ir=True, print_accessor_ir=True) - -A = ti.field(ti.f32, shape=()) +n = 10000 @ti.kernel -def func(): - a = 0 - for i in range(10): - a -= i - A[None] = a +def inc(a: ti.ext_arr()): + for i in range(n): + a[i] += i + +x = np.zeros(dtype=np.int32, shape=n) +for i in range(10): + inc(x) + +for i in range(n): + assert x[i] == i * 10 -func() -assert A[None] == -45 From 8db30473fa1f16b1dc2ee6f9a386ac7673253184 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Tue, 13 Oct 2020 21:57:05 -0400 Subject: [PATCH 09/18] fix all x64 tests --- taichi/transforms/make_block_local.cpp | 2 +- taichi/transforms/make_thread_local.cpp | 2 +- tests/python/test_reduction.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index 49c4446f9e9bd..1a29df0eca7d2 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -23,7 +23,7 @@ void make_block_local_offload(OffloadedStmt *offload) { for (auto &pad : pads->pads) { auto snode = pad.first; - auto data_type = snode->dt; + auto data_type = snode->dt.ptr_removed(); auto dtype_size = data_type_size(data_type); bool bls_has_read = pad.second.total_flags & AccessFlag::read; diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index a0db654c7d1ef..03d2c40c2c6fd 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -113,7 +113,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { // TODO: sort thread local storage variables according to dtype_size to // reduce buffer fragmentation. for (auto dest : valid_reduction_values) { - auto data_type = dest->ret_type.data_type; + auto data_type = dest->ret_type.data_type.ptr_removed(); auto dtype_size = data_type_size(data_type); // Step 1: // Create thread local storage diff --git a/tests/python/test_reduction.py b/tests/python/test_reduction.py index 4f12b6a0df848..e7db7c730efda 100644 --- a/tests/python/test_reduction.py +++ b/tests/python/test_reduction.py @@ -42,7 +42,7 @@ def test_reduction_single_i32(): _test_reduction_single(ti.i32, lambda x, y: x % 2**32 == y % 2**32) -@ti.archs_excluding(ti.opengl) +@ti.test(exclude=ti.opengl) def test_reduction_single_u32(): _test_reduction_single(ti.u32, lambda x, y: x % 2**32 == y % 2**32) From a48c4806f42a78df48b87676214ff44b633b594b Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 11:20:00 -0400 Subject: [PATCH 10/18] fix opengl tests --- taichi/backends/opengl/codegen_opengl.cpp | 2 ++ taichi/backends/opengl/opengl_data_types.h | 7 +++++-- tmp.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index dd2d78eea1466..c081d0733320d 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -428,6 +428,7 @@ class KernelGen : public IRVisitor { } void visit(ExternalPtrStmt *stmt) override { + TI_TAG; TI_ASSERT(stmt->width() == 1); const auto linear_index_name = fmt::format("_li_{}", stmt->short_name()); emit("int {} = 0;", linear_index_name); @@ -572,6 +573,7 @@ class KernelGen : public IRVisitor { } void visit(AtomicOpStmt *stmt) override { + TI_TAG; TI_ASSERT(stmt->width() == 1); auto dt = stmt->dest->element_type().ptr_removed(); if (dt == PrimitiveType::i32 || diff --git a/taichi/backends/opengl/opengl_data_types.h b/taichi/backends/opengl/opengl_data_types.h index e80018d9fc36d..16b21a1048d23 100644 --- a/taichi/backends/opengl/opengl_data_types.h +++ b/taichi/backends/opengl/opengl_data_types.h @@ -8,6 +8,7 @@ namespace opengl { inline std::string opengl_data_type_name(DataType dt) { // https://www.khronos.org/opengl/wiki/Data_Type_(GLSL) + dt.set_is_pointer(false); if (dt == PrimitiveType::f32) return "float"; else if (dt == PrimitiveType::f64) @@ -16,8 +17,9 @@ inline std::string opengl_data_type_name(DataType dt) { return "int"; else if (dt == PrimitiveType::i64) return "int64_t"; - else - TI_NOT_IMPLEMENTED; + else { + TI_ERROR("Type {} not supported.", dt->to_string()); + } } inline bool is_opengl_binary_op_infix(BinaryOpType type) { @@ -32,6 +34,7 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) { } inline int opengl_data_address_shifter(DataType type) { + type.set_is_pointer(false); if (type == PrimitiveType::f32 || type == PrimitiveType::i32) return 2; else if (type == PrimitiveType::f64 || type == PrimitiveType::i64) { diff --git a/tmp.py b/tmp.py index fb19c680f6de9..1b2a32f8e252c 100644 --- a/tmp.py +++ b/tmp.py @@ -1,7 +1,7 @@ import taichi as ti import numpy as np -ti.init(print_ir=True, print_accessor_ir=True) +ti.init(arch=ti.opengl, print_ir=True, print_accessor_ir=True) n = 10000 From 84e3209ff691129863760dc0e145e910829a874a Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 11:49:23 -0400 Subject: [PATCH 11/18] all tests passed --- taichi/backends/cc/codegen_cc.cpp | 17 ++++++++++------- tmp.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index c94648a69d192..b0fbbcea976bf 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -136,7 +136,8 @@ class CCTransformer : public IRVisitor { void visit(GlobalTemporaryStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - auto ptr_type = cc_data_type_name(stmt->element_type()) + " *"; + auto ptr_type = + cc_data_type_name(stmt->element_type().ptr_removed()) + " *"; auto var = define_var(ptr_type, stmt->raw_name()); emit("{} = ({}) (ti_ctx->gtmp + {});", var, ptr_type, stmt->offset); } @@ -161,17 +162,19 @@ class CCTransformer : public IRVisitor { offset = fmt::format("({} * {} + {})", offset, stride, stmt->indices[i]->raw_name()); } - auto var = define_var(cc_data_type_name(stmt->element_type()) + " *", - stmt->raw_name()); + auto var = + define_var(cc_data_type_name(stmt->element_type().ptr_removed()) + " *", + stmt->raw_name()); emit("{} = {} + {};", var, stmt->base_ptrs[0]->raw_name(), offset); } void visit(ArgLoadStmt *stmt) override { if (stmt->is_ptr) { - auto var = define_var(cc_data_type_name(stmt->element_type()) + " *", - stmt->raw_name()); + auto var = define_var( + cc_data_type_name(stmt->element_type().ptr_removed()) + " *", + stmt->raw_name()); emit("{} = ti_ctx->args[{}].ptr_{};", var, stmt->arg_id, - data_type_short_name(stmt->element_type())); + data_type_short_name(stmt->element_type().ptr_removed())); } else { auto var = define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()); @@ -377,7 +380,7 @@ class CCTransformer : public IRVisitor { const auto dest_ptr = stmt->dest->raw_name(); const auto src_name = stmt->val->raw_name(); const auto op = cc_atomic_op_type_symbol(stmt->op_type); - const auto type = stmt->element_type(); + const auto type = stmt->dest->element_type().ptr_removed(); auto var = define_var(cc_data_type_name(type), stmt->raw_name()); emit("{} = *{};", var, dest_ptr); if (stmt->op_type == AtomicOpType::max || diff --git a/tmp.py b/tmp.py index 1b2a32f8e252c..cfa492206121f 100644 --- a/tmp.py +++ b/tmp.py @@ -7,7 +7,7 @@ @ti.kernel def inc(a: ti.ext_arr()): - for i in range(n): + for i in ti.ndrange([0, n]): a[i] += i x = np.zeros(dtype=np.int32, shape=n) From 54a0ba7e06d50bef75de88c724820ae54cf3be29 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 11:49:54 -0400 Subject: [PATCH 12/18] format --- taichi/ir/ir.cpp | 3 +-- taichi/ir/statements.h | 6 ++---- tmp.py | 3 ++- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 80377c248a2e2..58231a1a9f879 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -166,8 +166,7 @@ Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { ret_type = stmt.ret_type; } -Stmt::Stmt(const DataType &data_type): Stmt() { - +Stmt::Stmt(const DataType &data_type) : Stmt() { } Stmt *Stmt::insert_before_me(std::unique_ptr &&new_stmt) { diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index ae251bc070713..3e49c1be1d5af 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -934,8 +934,7 @@ class GlobalTemporaryStmt : public Stmt { public: std::size_t offset; - GlobalTemporaryStmt(std::size_t offset, DataType ret_type) - : offset(offset) { + GlobalTemporaryStmt(std::size_t offset, DataType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; } @@ -952,8 +951,7 @@ class ThreadLocalPtrStmt : public Stmt { public: std::size_t offset; - ThreadLocalPtrStmt(std::size_t offset, DataType ret_type) - : offset(offset) { + ThreadLocalPtrStmt(std::size_t offset, DataType ret_type) : offset(offset) { this->ret_type = ret_type; TI_STMT_REG_FIELDS; } diff --git a/tmp.py b/tmp.py index cfa492206121f..e2c5d958ca741 100644 --- a/tmp.py +++ b/tmp.py @@ -5,15 +5,16 @@ n = 10000 + @ti.kernel def inc(a: ti.ext_arr()): for i in ti.ndrange([0, n]): a[i] += i + x = np.zeros(dtype=np.int32, shape=n) for i in range(10): inc(x) for i in range(n): assert x[i] == i * 10 - From 63ae4b221084ed6ae8d552e40a6dee3ffbd89a74 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 12:10:35 -0400 Subject: [PATCH 13/18] fix GetCh type --- taichi/transforms/type_check.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 8fd7c6c0eaf25..db4d08aa83447 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -376,10 +376,7 @@ class TypeCheck : public IRVisitor { } void visit(GetChStmt *stmt) { - stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt); - if (stmt->output_snode->type != SNodeType::place) { - stmt->ret_type.set_is_pointer(true); - } // for place SNodes GetCh directly yields a numerical value. + stmt->ret_type = LegacyVectorType(1, stmt->output_snode->dt, true); } void visit(OffloadedStmt *stmt) { From 4a39b62a66d7f49dd7107dcecfe0ddb33f594d7c Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 12:16:53 -0400 Subject: [PATCH 14/18] fix local ptr types --- taichi/transforms/make_block_local.cpp | 6 +++--- taichi/transforms/make_thread_local.cpp | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index 1a29df0eca7d2..99d908ac72613 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -180,7 +180,7 @@ void make_block_local_offload(OffloadedStmt *offload) { TypedConstant(data_type, 0)); } auto bls_ptr = element_block->push_back( - bls_element_offset_bytes, LegacyVectorType(1, data_type)); + bls_element_offset_bytes, LegacyVectorType(1, data_type, true)); element_block->push_back(bls_ptr, value); }); } @@ -269,7 +269,7 @@ void make_block_local_offload(OffloadedStmt *offload) { bls.push_back(TypedConstant((int32)bls_offset))); bls.push_back(bls_element_offset, - LegacyVectorType(1, data_type)); + LegacyVectorType(1, data_type, true)); global_ptr->replace_with(std::move(bls)); } } @@ -283,7 +283,7 @@ void make_block_local_offload(OffloadedStmt *offload) { Stmt *bls_element_offset_bytes) { // Store/accumulate from BLS to global auto bls_ptr = element_block->push_back( - bls_element_offset_bytes, LegacyVectorType(1, data_type)); + bls_element_offset_bytes, LegacyVectorType(1, data_type, true)); auto bls_val = element_block->push_back(bls_ptr); auto global_pointer = diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index 03d2c40c2c6fd..7c291a61e6ec5 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -127,7 +127,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { tls_offset += (dtype_size - tls_offset % dtype_size) % dtype_size; auto tls_ptr = offload->tls_prologue->push_back( - tls_offset, LegacyVectorType(1, data_type)); + tls_offset, LegacyVectorType(1, data_type, true)); auto zero = offload->tls_prologue->insert( std::make_unique(TypedConstant(data_type, 0)), -1); @@ -139,10 +139,10 @@ void make_thread_local_offload(OffloadedStmt *offload) { // Step 2: // Make loop body accumulate to TLS ptr instead of global ptr { - auto tls_ptr = - offload->body->insert(Stmt::make( - tls_offset, LegacyVectorType(1, data_type)), - 0); + auto tls_ptr = offload->body->insert( + Stmt::make(tls_offset, + LegacyVectorType(1, data_type, true)), + 0); dest->replace_with(tls_ptr); } @@ -154,7 +154,7 @@ void make_thread_local_offload(OffloadedStmt *offload) { offload->tls_epilogue->parent_stmt = offload; } auto tls_ptr = offload->tls_epilogue->push_back( - tls_offset, LegacyVectorType(1, data_type)); + tls_offset, LegacyVectorType(1, data_type, true)); // TODO: do not use global load from TLS. auto tls_load = offload->tls_epilogue->push_back(tls_ptr); auto global_ptr = offload->tls_epilogue->insert( From 9df8323c3e608066f22a43ac7af1e475506fcbc4 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 15:56:33 -0400 Subject: [PATCH 15/18] clean --- examples/laplace.py | 2 +- taichi/backends/opengl/codegen_opengl.cpp | 2 -- taichi/backends/opengl/opengl_data_types.h | 2 ++ taichi/codegen/codegen_llvm.cpp | 1 - taichi/ir/ir.cpp | 5 +---- taichi/ir/ir.h | 3 --- taichi/lang_util.cpp | 8 +++++--- taichi/lang_util.h | 2 +- taichi/transforms/auto_diff.cpp | 9 +-------- tmp.py | 20 -------------------- 10 files changed, 11 insertions(+), 43 deletions(-) delete mode 100644 tmp.py diff --git a/examples/laplace.py b/examples/laplace.py index 9d09d33c94183..07f75e893b80f 100644 --- a/examples/laplace.py +++ b/examples/laplace.py @@ -1,6 +1,6 @@ import taichi as ti -ti.init(arch=ti.cpu, print_ir=True, print_accessor_ir=True) +ti.init(arch=ti.cpu) x, y = ti.field(ti.f32), ti.field(ti.f32) ti.root.dense(ti.ij, 16).place(x, y) diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index c081d0733320d..dd2d78eea1466 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -428,7 +428,6 @@ class KernelGen : public IRVisitor { } void visit(ExternalPtrStmt *stmt) override { - TI_TAG; TI_ASSERT(stmt->width() == 1); const auto linear_index_name = fmt::format("_li_{}", stmt->short_name()); emit("int {} = 0;", linear_index_name); @@ -573,7 +572,6 @@ class KernelGen : public IRVisitor { } void visit(AtomicOpStmt *stmt) override { - TI_TAG; TI_ASSERT(stmt->width() == 1); auto dt = stmt->dest->element_type().ptr_removed(); if (dt == PrimitiveType::i32 || diff --git a/taichi/backends/opengl/opengl_data_types.h b/taichi/backends/opengl/opengl_data_types.h index 16b21a1048d23..72bb1f6406452 100644 --- a/taichi/backends/opengl/opengl_data_types.h +++ b/taichi/backends/opengl/opengl_data_types.h @@ -34,6 +34,8 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) { } inline int opengl_data_address_shifter(DataType type) { + // TODO: fail loudly when feeding a pointer type to this function after type + // system upgrade. type.set_is_pointer(false); if (type == PrimitiveType::f32 || type == PrimitiveType::i32) return 2; diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index ba52551f3d3f4..84028c6b18816 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -132,7 +132,6 @@ void CodeGenLLVM::visit(AllocaStmt *stmt) { TI_ASSERT(stmt->width() == 1); llvm_val[stmt] = create_entry_block_alloca(stmt->ret_type.data_type, stmt->ret_type.is_pointer()); - // TODO: upgrade to new type system // initialize as zero if element is not a pointer if (!stmt->ret_type.is_pointer()) builder->CreateStore(tlctx->get_constant(stmt->ret_type.data_type, 0), diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 58231a1a9f879..6dcbe90bf3984 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -166,9 +166,6 @@ Stmt::Stmt(const Stmt &stmt) : field_manager(this), fields_registered(false) { ret_type = stmt.ret_type; } -Stmt::Stmt(const DataType &data_type) : Stmt() { -} - Stmt *Stmt::insert_before_me(std::unique_ptr &&new_stmt) { auto ret = new_stmt.get(); TI_ASSERT(parent); @@ -224,7 +221,7 @@ std::string Stmt::type_hint() const { if (ret_type.data_type == PrimitiveType::unknown) return ""; else - return fmt::format("|{}| ", ret_type.to_string()); + return fmt::format("<{}> ", ret_type.to_string()); } std::string Stmt::type() { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 193701e8f1643..531408db99772 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -499,9 +499,6 @@ class Stmt : public IRNode { Stmt(); Stmt(const Stmt &stmt); - // TODO: remove this after type refactoring is done - Stmt(const DataType &data_type); - int &width() { return ret_type.width; } diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 09e41f6fdd778..ca0e45add4f87 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -65,6 +65,8 @@ real measure_cpe(std::function target, return elasped_cycles / float64(total_batches * elements_per_call); } +// TODO: Remove data_type_short_name. Having two names for a data type is +// confusing. std::string data_type_name(DataType t) { return data_type_short_name(t); } @@ -314,9 +316,9 @@ class TypePromotionMapping { public: TypePromotionMapping() { #define TRY_SECOND(x, y) \ - mapping[std::make_pair(get_data_primitive_type(), \ - get_data_primitive_type())] = \ - get_data_primitive_type() + \ + mapping[std::make_pair(get_primitive_data_type(), \ + get_primitive_data_type())] = \ + get_primitive_data_type() + \ std::declval())>(); #define TRY_FIRST(x) \ TRY_SECOND(x, float32); \ diff --git a/taichi/lang_util.h b/taichi/lang_util.h index eab69c3956a36..50f1f9f75ab3e 100644 --- a/taichi/lang_util.h +++ b/taichi/lang_util.h @@ -50,7 +50,7 @@ inline DataType get_data_type() { } template -inline PrimitiveType::primitive_type get_data_primitive_type() { +inline PrimitiveType::primitive_type get_primitive_data_type() { if (std::is_same()) { return PrimitiveType::primitive_type::f32; } else if (std::is_same()) { diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index b954e80121901..b85d55aef41c1 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -737,17 +737,10 @@ class BackupSSA : public BasicStmtVisitor { Stmt *load(Stmt *stmt) { if (backup_alloca.find(stmt) == backup_alloca.end()) { - auto alloca = - Stmt::make(stmt->width(), stmt->ret_type.data_type); - // TODO: the line below was deleted during type system refactoring. - // Hopefully it's no longer needed. - // alloca->ret_type.set_is_pointer(stmt->ret_type.is_pointer()); + auto alloca = Stmt::make(stmt->width(), stmt->ret_type); auto alloca_ptr = alloca.get(); independent_block->insert(std::move(alloca), 0); auto local_store = Stmt::make(alloca_ptr, stmt); - // TODO: the line below was deleted during type system refactoring. - // Hopefully it's no longer needed. - // local_store->ret_type.set_is_pointer(stmt->ret_type.is_pointer()); stmt->insert_after_me(std::move(local_store)); backup_alloca[stmt] = alloca_ptr; } diff --git a/tmp.py b/tmp.py deleted file mode 100644 index e2c5d958ca741..0000000000000 --- a/tmp.py +++ /dev/null @@ -1,20 +0,0 @@ -import taichi as ti -import numpy as np - -ti.init(arch=ti.opengl, print_ir=True, print_accessor_ir=True) - -n = 10000 - - -@ti.kernel -def inc(a: ti.ext_arr()): - for i in ti.ndrange([0, n]): - a[i] += i - - -x = np.zeros(dtype=np.int32, shape=n) -for i in range(10): - inc(x) - -for i in range(n): - assert x[i] == i * 10 From 8dee592feb4497fd28f5ea2c125930d067929cc0 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 16:23:30 -0400 Subject: [PATCH 16/18] finalize --- taichi/backends/opengl/opengl_data_types.h | 2 +- taichi/ir/type.cpp | 2 +- taichi/ir/type.h | 4 ++-- taichi/lang_util.cpp | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/taichi/backends/opengl/opengl_data_types.h b/taichi/backends/opengl/opengl_data_types.h index 72bb1f6406452..59ab24b1d83f5 100644 --- a/taichi/backends/opengl/opengl_data_types.h +++ b/taichi/backends/opengl/opengl_data_types.h @@ -34,7 +34,7 @@ inline bool is_opengl_binary_op_different_return_type(BinaryOpType type) { } inline int opengl_data_address_shifter(DataType type) { - // TODO: fail loudly when feeding a pointer type to this function after type + // TODO: fail loudly when feeding a pointer type to this function, after type // system upgrade. type.set_is_pointer(false); if (type == PrimitiveType::f32 || type == PrimitiveType::i32) diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index b123732274a48..e98e77842a28f 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -68,12 +68,12 @@ std::string PrimitiveType::to_string() const { } DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) { + TI_ASSERT(width == 1); if (is_pointer) { return Program::get_type_factory().get_pointer_type(data_type.get_ptr()); } else { return data_type; } - TI_ASSERT(width == 1); } TLANG_NAMESPACE_END diff --git a/taichi/ir/type.h b/taichi/ir/type.h index cbb3353d7f7b8..29f31f090103c 100644 --- a/taichi/ir/type.h +++ b/taichi/ir/type.h @@ -64,7 +64,8 @@ class DataType { return ptr_; } - // To be compatible with LegacyVectorType + // Temporary API and members + // for LegacyVectorType-compatibility int width{1}; DataType &data_type; @@ -81,7 +82,6 @@ class DataType { void set_is_pointer(bool ptr); - // Temporary API DataType ptr_removed() const; private: diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index ca0e45add4f87..1f7300857c3b3 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -91,9 +91,9 @@ std::string data_type_format(DataType dt) { int data_type_size(DataType t) { // TODO: - // 1. ensure in the old code, pointer attributes of t are correct (by setting + // 1. Ensure in the old code, pointer attributes of t are correct (by setting // a loud failure on pointers); - // 2. support pointer here. + // 2. Support pointer types here. t.set_is_pointer(false); if (false) { } else if (t == PrimitiveType::f16) From 13345ec79bcca1bec3dcb623921c92fd182705a9 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 15:11:20 -0400 Subject: [PATCH 17/18] fix some metal tests --- taichi/backends/metal/codegen_metal.cpp | 3 ++- taichi/backends/metal/data_types.cpp | 1 + tests/python/test_ad_atomic.py | 8 +++----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index c2de4856e403f..d1ce200eeeb17 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -357,7 +357,8 @@ class KernelCodegen : public IRVisitor { void visit(ThreadLocalPtrStmt *stmt) override { TI_ASSERT(stmt->width() == 1); emit("thread auto* {} = reinterpret_cast({} + {});", - stmt->raw_name(), metal_data_type_name(stmt->element_type()), + stmt->raw_name(), + metal_data_type_name(stmt->element_type().ptr_removed()), kTlsBufferName, stmt->offset); } diff --git a/taichi/backends/metal/data_types.cpp b/taichi/backends/metal/data_types.cpp index 101251a59a236..14112c0db8041 100644 --- a/taichi/backends/metal/data_types.cpp +++ b/taichi/backends/metal/data_types.cpp @@ -4,6 +4,7 @@ TLANG_NAMESPACE_BEGIN namespace metal { MetalDataType to_metal_type(DataType dt) { + dt.set_is_pointer(false); #define METAL_CASE(x) else if (dt == PrimitiveType::x) return MetalDataType::x if (false) { } diff --git a/tests/python/test_ad_atomic.py b/tests/python/test_ad_atomic.py index 21a3ec30faefe..3539bf6b809c2 100644 --- a/tests/python/test_ad_atomic.py +++ b/tests/python/test_ad_atomic.py @@ -2,14 +2,12 @@ from taichi import approx -@ti.all_archs +@ti.test() def test_ad_reduce(): - x = ti.field(ti.f32) - loss = ti.field(ti.f32) - N = 16 - ti.root.place(loss, loss.grad).dense(ti.i, N).place(x, x.grad) + x = ti.field(dtype=ti.f32, shape=N, needs_grad=True) + loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True) @ti.kernel def func(): From e3e859bcad45e6f51430554ab97bdc4cd8e12250 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Wed, 14 Oct 2020 16:31:31 -0400 Subject: [PATCH 18/18] try to fix metal --- taichi/backends/metal/codegen_metal.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index d1ce200eeeb17..3287296bb1d11 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -349,7 +349,7 @@ class KernelCodegen : public IRVisitor { void visit(GlobalTemporaryStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - const auto dt = metal_data_type_name(stmt->element_type()); + const auto dt = metal_data_type_name(stmt->element_type().ptr_removed()); emit("device {}* {} = reinterpret_cast({} + {});", dt, stmt->raw_name(), dt, kGlobalTmpsBufferName, stmt->offset); }