Skip to content

Commit

Permalink
move expr unification to type check
Browse files Browse the repository at this point in the history
  • Loading branch information
AD1024 committed Aug 18, 2022
1 parent f3ff3d2 commit 39ed915
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 40 deletions.
52 changes: 39 additions & 13 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,37 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
ctx->push_back(std::move(unary));
}

Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) {
TI_ASSERT(dt->is<TensorType>());
auto tensor_type = dt->as<TensorType>();
auto elt_type = tensor_type->get_element_type();
TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
"Only primitive types are supported in Tensors, got {}",
elt_type->to_string());
std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
return Expr::make<MatrixExpression>(broadcast_values,
tensor_type->get_shape(), elt_type);
}

std::tuple<Expr, Expr> unify_binop_operands(const Expr &e1, const Expr &e2) {
if ((!e1->ret_type->is<TensorType>() && !e2->ret_type->is<TensorType>()) ||
(e1->ret_type->is<TensorType>() && e2->ret_type->is<TensorType>())) {
return std::tuple(e1, e2);
}
if (!e1->ret_type->is<TensorType>()) {
return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
}
return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
}

void BinaryOpExpression::type_check(CompileConfig *config) {
auto [unified_l, unified_r] = unify_binop_operands(lhs, rhs);
lhs = unified_l;
rhs = unified_r;
if (lhs->ret_type == PrimitiveType::unknown)
lhs.type_check(config);
if (rhs->ret_type == PrimitiveType::unknown)
rhs.type_check(config);
TI_ASSERT_TYPE_CHECKED(lhs);
TI_ASSERT_TYPE_CHECKED(rhs);
auto lhs_type = lhs->ret_type;
Expand All @@ -194,22 +224,18 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
};

if (lhs_type->is<TensorType>()) {
TI_ASSERT(rhs_type->is<TensorType>());
auto rhs_tensor_type = rhs_type->cast<TensorType>();
auto dtype = lhs_type->as<TensorType>()->get_element_type();
if (rhs_type->is<PrimitiveType>()) {
ret_type = promoted_type(dtype, rhs_type);
} else {
TI_ASSERT(rhs_type->is<TensorType>());
auto rhs_tensor_type = rhs_type->cast<TensorType>();
if (rhs_tensor_type->get_shape() !=
lhs_type->cast<TensorType>()->get_shape())
error();
auto rhs_elem_type = rhs_type->as<TensorType>()->get_element_type();
if (rhs_elem_type != PrimitiveType::unknown)
ret_type = promoted_type(dtype, rhs_elem_type);
}
if (rhs_tensor_type->get_shape() !=
lhs_type->cast<TensorType>()->get_shape())
error();
auto rhs_elem_type = rhs_type->as<TensorType>()->get_element_type();
if (rhs_elem_type != PrimitiveType::unknown)
dtype = promoted_type(dtype, rhs_elem_type);
// TODO: shape check!
ret_type = TypeFactory::create_tensor_type(
lhs_type->cast<TensorType>()->get_shape(), ret_type);
lhs_type->cast<TensorType>()->get_shape(), dtype);
return;
}

Expand Down
27 changes: 0 additions & 27 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,33 +379,6 @@ class BinaryOpExpression : public Expression {

BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
: type(type), lhs(lhs), rhs(rhs) {
auto to_broadcast_tensor = [](const Expr &elt, const DataType &dt) -> Expr {
TI_ASSERT(dt->is<TensorType>());
auto tensor_type = dt->as<TensorType>();
auto elt_type = tensor_type->get_element_type();
TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
"Only primitive types are supported in Tensors, got {}",
elt_type->to_string());
std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
return Expr::make<MatrixExpression>(broadcast_values,
tensor_type->get_shape(), elt_type);
};

auto unify_expr = [&](const Expr &e1, const Expr &e2) {
if ((!e1->ret_type->is<TensorType>() &&
!e2->ret_type->is<TensorType>()) ||
(e1->ret_type->is<TensorType>() && e2->ret_type->is<TensorType>())) {
return std::tuple(e1, e2);
}
if (!e1->ret_type->is<TensorType>()) {
return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
}
return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
};

auto [unified_l, unified_r] = unify_expr(lhs, rhs);
this->lhs = unified_l;
this->rhs = unified_r;
}

void type_check(CompileConfig *config) override;
Expand Down

0 comments on commit 39ed915

Please sign in to comment.