From 5d34d49308229464c3563d20ee95ffb216f7d0ab Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Thu, 14 Apr 2022 00:08:21 -0700 Subject: [PATCH 1/3] add 16 bit float immediate number --- taichi/codegen/spirv/spirv_ir_builder.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/taichi/codegen/spirv/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp index 1932ffc82f0b3..d4deed3a80851 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -214,6 +214,11 @@ Value IRBuilder::float_immediate_number(const SType &dtype, uint32_t *ptr = reinterpret_cast(&fvalue); uint64_t data = ptr[0]; return get_const(dtype, &data, cache); + } else if (data_type_bits(dtype.dt) == 16) { + float fvalue = static_cast(value); + uint16_t *ptr = reinterpret_cast(&fvalue); + uint64_t data = ptr[0]; + return get_const(dtype, &data, cache); } else { TI_ERROR("Type {} not supported.", dtype.dt->to_string()); } From c5efeaab5cdec8e54004abe2825db9efd18aa6ef Mon Sep 17 00:00:00 2001 From: Cheng Cao Date: Thu, 14 Apr 2022 10:28:43 -0700 Subject: [PATCH 2/3] Update f16 test --- tests/python/test_f16.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_f16.py b/tests/python/test_f16.py index 34705e1b72b16..de8fe9269b05c 100644 --- a/tests/python/test_f16.py +++ b/tests/python/test_f16.py @@ -213,7 +213,7 @@ def paint(t: float): c = ti.Vector([-0.8, ti.cos(t) * 0.2], dt=ti.f16) z = ti.Vector([ i / n - 1, j / n - 0.5 - ]) * 2 # FIXME: the kernel crashes when z stores f16 + ], dt=ti.f16) * 2 iterations = 0 while z.norm() < 20 and iterations < 50: z = complex_sqr(z) + c From f3f60123f43b6d76e9184a1754a8c9ab47c5f6ac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Apr 2022 17:29:43 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/python/test_f16.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/test_f16.py b/tests/python/test_f16.py index de8fe9269b05c..5f933aea8fd81 100644 --- a/tests/python/test_f16.py +++ b/tests/python/test_f16.py @@ -211,9 +211,7 @@ def complex_sqr(z): def paint(t: float): for i, j in pixels: # Parallelized over all pixels c = ti.Vector([-0.8, ti.cos(t) * 0.2], dt=ti.f16) - z = ti.Vector([ - i / n - 1, j / n - 0.5 - ], dt=ti.f16) * 2 + z = ti.Vector([i / n - 1, j / n - 0.5], dt=ti.f16) * 2 iterations = 0 while z.norm() < 20 and iterations < 50: z = complex_sqr(z) + c