Skip to content

Commit

Permalink
[Bug] [ir] Fix and refactor type check for atomic ops (taichi-dev#4858)
Browse files Browse the repository at this point in the history
* [Bug] [ir] Fix and refactor type check for atomic ops

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove accidental change

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and k-ye committed May 5, 2022
1 parent c38c7cd commit af29e73
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 41 deletions.
1 change: 1 addition & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ void AtomicOpExpression::flatten(FlattenContext *ctx) {
ctx->push_back<AtomicOpStmt>(op_type, dest->stmt, expr->stmt);
}
stmt = ctx->back_stmt();
stmt->tb = tb;
}

void SNodeOpExpression::type_check(CompileConfig *) {
Expand Down
69 changes: 28 additions & 41 deletions taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@ class TypeCheck : public IRVisitor {
private:
CompileConfig config_;

Type *type_check_store(Stmt *stmt,
Stmt *dst,
Stmt *&val,
const std::string &stmt_name) {
auto dst_type = dst->ret_type.ptr_removed();
if (dst_type->is<CustomIntType>() || dst_type->is<CustomFloatType>()) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_type = dst_type->get_compute_type();
}
if (dst_type != val->ret_type) {
auto promoted = promoted_type(dst_type, val->ret_type);
if (dst_type != promoted) {
TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(),
stmt_name, dst_type->to_string(), val->ret_data_type_name(),
stmt->tb);
}
val = insert_type_cast_before(stmt, val, dst_type);
}
return dst_type;
}

public:
explicit TypeCheck(const CompileConfig &config) : config_(config) {
allow_undefined_visitor = true;
Expand Down Expand Up @@ -57,20 +79,9 @@ class TypeCheck : public IRVisitor {
void visit(AtomicOpStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
// TODO(type): test_ad_for fails if we assume dest is a pointer type.
auto dst_type = stmt->dest->ret_type.ptr_removed();
if (auto cit = dst_type->cast<CustomIntType>()) {
dst_type = cit->get_physical_type();
} else if (auto cft = dst_type->cast<CustomFloatType>()) {
auto cit = cft->get_digits_type()->as<CustomIntType>();
dst_type = cit->get_physical_type();
} else if (stmt->val->ret_type != dst_type) {
TI_WARN("[{}] Atomic {} ({} to {}) may lose precision\n{}", stmt->name(),
atomic_op_type_name(stmt->op_type),
data_type_name(stmt->val->ret_type), data_type_name(dst_type),
stmt->tb);
stmt->val = insert_type_cast_before(stmt, stmt->val, dst_type);
}
stmt->ret_type = dst_type;
stmt->ret_type = type_check_store(
stmt, stmt->dest, stmt->val,
fmt::format("Atomic {}", atomic_op_type_name(stmt->op_type)));
}

void visit(LocalLoadStmt *stmt) override {
Expand Down Expand Up @@ -106,17 +117,8 @@ class TypeCheck : public IRVisitor {
// Infer data type for alloca
stmt->dest->ret_type = stmt->val->ret_type;
}
auto dst_value_type = stmt->dest->ret_type.ptr_removed();
if (dst_value_type != stmt->val->ret_type) {
auto promoted = promoted_type(dst_value_type, stmt->val->ret_type);
if (dst_value_type != promoted) {
TI_WARN("[{}] Local store may lose precision {} <- {}\n{}",
stmt->name(), dst_value_type->to_string(),
stmt->val->ret_data_type_name(), stmt->tb);
}
stmt->val = insert_type_cast_before(stmt, stmt->val, dst_value_type);
}
stmt->ret_type = dst_value_type;
stmt->ret_type =
type_check_store(stmt, stmt->dest, stmt->val, "Local store");
}

void visit(GlobalLoadStmt *stmt) override {
Expand Down Expand Up @@ -180,22 +182,7 @@ class TypeCheck : public IRVisitor {
}

void visit(GlobalStoreStmt *stmt) override {
auto dst_value_type = stmt->dest->ret_type.ptr_removed();
if (dst_value_type->is<CustomIntType>() ||
dst_value_type->is<CustomFloatType>()) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_value_type = dst_value_type->get_compute_type();
}
if (dst_value_type != stmt->val->ret_type) {
auto promoted = promoted_type(dst_value_type, stmt->val->ret_type);
if (dst_value_type != promoted) {
TI_WARN("[{}] Global store may lose precision: {} <- {}\n{}",
stmt->name(), dst_value_type->to_string(),
stmt->val->ret_data_type_name(), stmt->tb);
}
stmt->val = insert_type_cast_before(stmt, stmt->val, dst_value_type);
}
type_check_store(stmt, stmt->dest, stmt->val, "Global store");
}

void visit(RangeForStmt *stmt) override {
Expand Down

0 comments on commit af29e73

Please sign in to comment.