Skip to content

Commit

Permalink
[Bug] [type] Fix wrong type cast in codegen of storing quant floats (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier authored Aug 18, 2022
1 parent ce86a1d commit f4305e7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) {
create_call("max_i32", {exponent_bits, tlctx->get_constant(0)});

// Compute the bit pointer of the exponent bits.
val = builder->CreateBitCast(exponent_bits, physical_type);
val = builder->CreateIntCast(exponent_bits, physical_type, false);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(exp));

if (bit_struct_val == nullptr) {
Expand All @@ -238,7 +238,7 @@ void TaskCodeGenLLVM::visit(BitStructStoreStmt *stmt) {
tlctx->get_constant(0));
val = builder->CreateSelect(exp_non_zero, digit_bits,
tlctx->get_constant(0));
val = builder->CreateBitCast(val, physical_type);
val = builder->CreateIntCast(val, physical_type, false);
val = builder->CreateShl(val, bit_struct->get_member_bit_offset(ch_id));
} else {
val = quant_int_or_quant_fixed_to_bits(val, dtype, physical_type);
Expand Down
5 changes: 3 additions & 2 deletions tests/python/test_quant_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from tests import test_utils


@pytest.mark.parametrize('max_num_bits', [32, 64])
@test_utils.test(require=ti.extension.quant)
def test_quant_float_unsigned():
def test_quant_float_unsigned(max_num_bits):
qflt = ti.types.quant.float(exp=6, frac=13, signed=False)
x = ti.field(dtype=qflt)

bitpack = ti.BitpackedFields(max_num_bits=32)
bitpack = ti.BitpackedFields(max_num_bits=max_num_bits)
bitpack.place(x)
ti.root.place(bitpack)

Expand Down

0 comments on commit f4305e7

Please sign in to comment.