diff --git a/python/taichi/lang/simt/warp.py b/python/taichi/lang/simt/warp.py index 21a45329e0810..6075ce5d4eb10 100644 --- a/python/taichi/lang/simt/warp.py +++ b/python/taichi/lang/simt/warp.py @@ -14,9 +14,10 @@ def any_nonzero(mask, predicate): "cuda_any_sync_i32", expr.make_expr_group(mask, predicate), False)) -def unique(): - # TODO - pass +def unique(mask, predicate): + return expr.Expr( + _ti_core.insert_internal_func_call( + "cuda_uni_sync_i32", expr.make_expr_group(mask, predicate), False)) def ballot(predicate): diff --git a/taichi/llvm/llvm_context.cpp b/taichi/llvm/llvm_context.cpp index 5ae61e84e8b36..f71bf3715de7e 100644 --- a/taichi/llvm/llvm_context.cpp +++ b/taichi/llvm/llvm_context.cpp @@ -363,6 +363,9 @@ std::unique_ptr TaichiLLVMContext::clone_module( patch_intrinsic("cuda_any", Intrinsic::nvvm_vote_any); patch_intrinsic("cuda_any_sync", Intrinsic::nvvm_vote_any_sync); + patch_intrinsic("cuda_uni", Intrinsic::nvvm_vote_uni); + patch_intrinsic("cuda_uni_sync", Intrinsic::nvvm_vote_uni_sync); + patch_intrinsic("cuda_ballot", Intrinsic::nvvm_vote_ballot); patch_intrinsic("cuda_ballot_sync", Intrinsic::nvvm_vote_ballot_sync); diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp index 162d312d867cc..a3ff2bfee596f 100644 --- a/taichi/runtime/llvm/runtime.cpp +++ b/taichi/runtime/llvm/runtime.cpp @@ -1064,6 +1064,14 @@ int32 cuda_any_sync_i32(u32 mask, int32 predicate) { return (int32)cuda_any_sync(mask, (bool)predicate); } +bool cuda_uni_sync(u32 mask, bool bit) { + return false; +} + +int32 cuda_uni_sync_i32(u32 mask, int32 predicate) { + return (int32)cuda_uni_sync(mask, (bool)predicate); +} + int32 cuda_ballot_sync(int32 mask, bool bit) { return 0; } diff --git a/tests/python/test_simt.py b/tests/python/test_simt.py index 2d7204b509d54..a29221d3735d6 100644 --- a/tests/python/test_simt.py +++ b/tests/python/test_simt.py @@ -64,8 +64,38 @@ def foo(): @test_utils.test(arch=ti.cuda) def test_unique(): - # TODO - pass + a = ti.field(dtype=ti.u32, shape=32) + b = ti.field(dtype=ti.u32, shape=32) + + @ti.kernel + def check(): + ti.loop_config(block_dim=32) + for i in range(32): + a[i] = ti.simt.warp.unique(ti.u32(0xFFFFFFFF), b[i]) + + for i in range(32): + b[i] = 0 + a[i] = -1 + + check() + + for i in range(32): + assert a[i] == 1 + + for i in range(32): + b[i] = i + 100 + + check() + + for i in range(32): + assert a[i] == 1 + + b[np.random.randint(0, 32)] = 0 + + check() + + for i in range(32): + assert a[i] == 0 @test_utils.test(arch=ti.cuda)