diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 95acab8df2f0c..745e82e6dc752 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -481,6 +481,11 @@ void AtomicOpExpression::type_check(CompileConfig *) { void AtomicOpExpression::flatten(FlattenContext *ctx) { // replace atomic sub with negative atomic add if (op_type == AtomicOpType::sub) { + if (val->ret_type != ret_type) { + val.set(Expr::make(UnaryOpType::cast_value, val, + ret_type)); + } + val.set(Expr::make(UnaryOpType::neg, val)); op_type = AtomicOpType::add; } diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py index 73896fcfa0dc5..0223e71e6f465 100644 --- a/tests/python/test_atomic.py +++ b/tests/python/test_atomic.py @@ -223,6 +223,47 @@ def test(): assert ret[None] == 1 +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_atomic_sub_with_type_promotion(): + # Test Case 1 + @ti.kernel + def test_u16_sub_u8() -> ti.uint16: + x: ti.uint16 = 1000 + y: ti.uint8 = 255 + + ti.atomic_sub(x, y) + return x + + res = test_u16_sub_u8() + assert res == 745 + + # Test Case 2 + @ti.kernel + def test_u8_sub_u16() -> ti.uint8: + x: ti.uint8 = 255 + y: ti.uint16 = 100 + + ti.atomic_sub(x, y) + return x + + res = test_u8_sub_u16() + assert res == 155 + + # Test Case 3 + A = ti.field(ti.uint8, shape=()) + B = ti.field(ti.uint16, shape=()) + + @ti.kernel + def test_with_field(): + v: ti.uint16 = 1000 + v -= A[None] + B[None] = v + + A[None] = 255 + test_with_field() + assert B[None] == 745 + + @test_utils.test() def test_atomic_sub_expr_evaled(): c = ti.field(ti.i32)