Skip to content

Commit

Permalink
[lang] [test] Fixed logical operation on numeric values and added sup…
Browse files Browse the repository at this point in the history
…port on real type

ghstack-source-id: 61e2efda51085405d6a5a6029f6dbeb052c8f9bb
Pull Request resolved: #8034
  • Loading branch information
listerily authored and Taichi Gardener committed May 18, 2023
1 parent 95c9f23 commit bb7afc3
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 10 deletions.
14 changes: 10 additions & 4 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,11 +622,17 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
llvm_val[stmt] =
builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::logical_and) {
llvm_val[stmt] =
builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
auto *lhs = builder->CreateIsNotNull(llvm_val[stmt->lhs]);
auto *rhs = builder->CreateIsNotNull(llvm_val[stmt->rhs]);
llvm_val[stmt] = builder->CreateAnd(lhs, rhs);
llvm_val[stmt] = builder->CreateZExtOrTrunc(
llvm_val[stmt], tlctx->get_data_type(stmt->ret_type));
} else if (op == BinaryOpType::logical_or) {
llvm_val[stmt] =
builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
auto *lhs = builder->CreateIsNotNull(llvm_val[stmt->lhs]);
auto *rhs = builder->CreateIsNotNull(llvm_val[stmt->rhs]);
llvm_val[stmt] = builder->CreateOr(lhs, rhs);
llvm_val[stmt] = builder->CreateZExtOrTrunc(
llvm_val[stmt], tlctx->get_data_type(stmt->ret_type));
} else if (op == BinaryOpType::bit_and) {
llvm_val[stmt] =
builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
Expand Down
21 changes: 16 additions & 5 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1111,11 +1111,22 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
DEFINE_BUILDER_CMP_UOP(eq, Equal);
DEFINE_BUILDER_CMP_UOP(ne, NotEqual);

#define DEFINE_BUILDER_LOGICAL_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
TI_ASSERT(a.stype.id == b.stype.id); \
TI_ASSERT(is_integral(a.stype.dt)); \
return make_value(spv::OpLogical##_Op, t_bool_, a, b); \
#define DEFINE_BUILDER_LOGICAL_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
TI_ASSERT(a.stype.id == b.stype.id); \
if (a.stype.id == t_bool_.id) { \
return make_value(spv::OpLogical##_Op, t_bool_, a, b); \
} else if (is_integral(a.stype.dt)) { \
Value val_a = make_value(spv::OpINotEqual, t_bool_, a, \
int_immediate_number(a.stype, 0)); \
Value val_b = make_value(spv::OpINotEqual, t_bool_, b, \
int_immediate_number(b.stype, 0)); \
Value val_ret = make_value(spv::OpLogical##_Op, t_bool_, val_a, val_b); \
return cast(a.stype, val_ret); \
} else { \
TI_ERROR("Logical ops on real types are not supported."); \
return Value(); \
} \
}

DEFINE_BUILDER_LOGICAL_OP(logical_and, And);
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ void BinaryOpExpression::type_check(const CompileConfig *config) {
if (binary_is_logical(type) && !(is_integral(lhs_type.get_element_type()) &&
is_integral(rhs_type.get_element_type())))
error();
if (is_comparison(type) || binary_is_logical(type)) {
if (is_comparison(type)) {
ret_type = make_dt(PrimitiveType::u1);
return;
}
Expand Down
89 changes: 89 additions & 0 deletions tests/python/test_logical_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import taichi as ti
from tests import test_utils


@test_utils.test(debug=True)
def test_logical_and_i32():
@ti.kernel
def func(x: ti.i32, y: ti.i32) -> ti.i32:
return x and y

assert func(1, 2) == 2
assert func(2, 1) == 1
assert func(0, 1) == 0
assert func(1, 0) == 0


@test_utils.test(debug=True)
def test_logical_or_i32():
@ti.kernel
def func(x: ti.i32, y: ti.i32) -> ti.i32:
return x or y

assert func(1, 2) == 1
assert func(2, 1) == 2
assert func(1, 0) == 1
assert func(0, 1) == 1


@test_utils.test(debug=True)
def test_logical_vec_i32():
vec4d = ti.types.vector(4, ti.i32)

@ti.kernel
def p() -> vec4d:
a = ti.Vector([2, 2, 0, 0])
b = ti.Vector([1, 0, 1, 0])
z = a or b
return z

@ti.kernel
def q() -> vec4d:
a = ti.Vector([2, 2, 0, 0])
b = ti.Vector([1, 0, 1, 0])
z = a and b
return z

x = p()
y = q()

assert x[0] == 1
assert x[1] == 1
assert x[2] == 1
assert x[3] == 0
assert y[0] == 1
assert y[1] == 0
assert y[2] == 0
assert y[3] == 0


# FIXME: bool vectors not supported on spir-v
@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True)
def test_logical_vec_bool():
vec4d = ti.types.vector(4, ti.u1)

@ti.kernel
def p() -> vec4d:
a = ti.Vector([True, True, False, False])
b = ti.Vector([True, False, True, False])
z = a or b
return z

@ti.kernel
def q() -> vec4d:
a = ti.Vector([True, True, False, False])
b = ti.Vector([True, False, True, False])
z = a and b
return z

x = p()
y = q()

assert x[0] == True
assert x[1] == True
assert x[2] == True
assert x[3] == False
assert y[0] == True
assert y[1] == False
assert y[2] == False
assert y[3] == False

0 comments on commit bb7afc3

Please sign in to comment.